Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d19d92d
refactor turbomind engine
lzhangzz Dec 19, 2025
a7ab1e5
simplify interface
lzhangzz Dec 22, 2025
e134666
minor
lzhangzz Dec 22, 2025
af46169
metrics
lzhangzz Dec 22, 2025
bedc618
refactor & logprobs
lzhangzz Dec 24, 2025
3d558e0
fix output logits
lzhangzz Dec 24, 2025
9cc59f8
fix logprobs
lzhangzz Dec 24, 2025
66be4db
rename
lzhangzz Dec 24, 2025
4c7c22a
mrope
lzhangzz Dec 25, 2025
08a0b89
Merge remote-tracking branch 'origin/main' into engine2a
lzhangzz Dec 26, 2025
5fdd37b
fix cuda-12.4 build
lzhangzz Dec 29, 2025
0a1ccdc
fix cuda-12.4 build
lzhangzz Dec 29, 2025
6223b2f
fix cuda-12.4 build
lzhangzz Dec 29, 2025
1066193
fix MSVC build
lzhangzz Dec 29, 2025
22ae933
fix MSVC build
lzhangzz Dec 29, 2025
299f595
fix guided decoding
lzhangzz Dec 29, 2025
4f3f3d6
fix warm-up for TP
lzhangzz Dec 29, 2025
0d08346
fix VLMs
lzhangzz Dec 30, 2025
2a0e0ad
refactor DP
lzhangzz Jan 5, 2026
581bc7b
remove redundant `rank` parameter
lzhangzz Jan 5, 2026
7bd3545
add `no queue` error & fix lint
lzhangzz Jan 5, 2026
620d712
fix vocab size
lzhangzz Jan 5, 2026
708419f
fix attn output for finished seqs
lzhangzz Jan 6, 2026
a541990
fix lint
lzhangzz Jan 7, 2026
949060c
fix lint
lzhangzz Jan 7, 2026
05e7059
add async flag
lzhangzz Jan 7, 2026
1165254
fix prefix caching
lzhangzz Jan 8, 2026
6fdeadd
minor fix
lzhangzz Jan 8, 2026
58c5cd1
Merge remote-tracking branch 'origin/main' into engine2a
lzhangzz Jan 8, 2026
5454038
fix lint
lzhangzz Jan 8, 2026
85dad95
fix lint
lzhangzz Jan 8, 2026
065c46a
fix typo
lzhangzz Jan 9, 2026
db86948
fix log level
lzhangzz Jan 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,9 @@ FetchContent_MakeAvailable(repo-cutlass)
FetchContent_Declare(
yaml-cpp
GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
GIT_TAG 0.8.0
GIT_TAG 65c1c270dbe7eec37b2df2531d7497c4eea79aee
GIT_PROGRESS TRUE
USES_TERMINAL_DOWNLOAD TRUE
PATCH_COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/yaml-cpp_cmake_policy.patch
UPDATE_DISCONNECTED 1
)
set(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL "Build static library of yaml-cpp")
FetchContent_MakeAvailable(yaml-cpp)
Expand All @@ -87,7 +85,6 @@ FetchContent_Declare(
GIT_SUBMODULES "3rdparty/dlpack"
GIT_PROGRESS TRUE
USES_TERMINAL_DOWNLOAD TRUE
UPDATE_DISCONNECTED 1
)

FetchContent_GetProperties(xgrammar)
Expand All @@ -110,6 +107,7 @@ endif()

# the environment variable
# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
# LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libasan.so.6:/usr/lib/x86_64-linux-gnu/libstdc++.so.6
# must be set at runtime
# https://github.com/google/sanitizers/issues/1322
if (LMDEPLOY_ASAN_ENABLE)
Expand Down
18 changes: 6 additions & 12 deletions lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _create_weight(self, model_comm):
# create weight
def _create_weight_func(device_id):
rank = self.node_id * self.gpu_count + device_id
model_comm.create_shared_weights(device_id, rank)
model_comm.create_weights(device_id, rank)

with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
futures = []
Expand All @@ -234,7 +234,7 @@ def _get_model_params(self):

def _get_params(device_id, que):
rank = self.node_id * self.gpu_count + device_id
out = model_comm.get_params(device_id, rank)
out = model_comm.get_weights(device_id, rank)
que.put(out)

que = Queue()
Expand Down Expand Up @@ -266,12 +266,6 @@ def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: Tu
# update some attributes of `engine_config` which depends on
# `session_len`
self.engine_config = engine_config
if engine_config.max_prefill_token_num is not None \
and engine_config.num_tokens_per_iter == 0:
self.engine_config.num_tokens_per_iter = \
engine_config.max_prefill_token_num
self.engine_config.max_prefill_iters = (self.config.session_len + engine_config.max_prefill_token_num -
1) // engine_config.max_prefill_token_num

# pack `self.config` and `self.engine_config` into a dict
self.config_dict = self.config.to_dict()
Expand All @@ -290,9 +284,9 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):

self._postprocess_config(tm_model.tm_config, engine_config)

model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
config=yaml.safe_dump(self.config_dict),
weight_type=self.config.model_config.weight_type)
model_comm = _tm.TurboMind.create(model_dir='',
config=yaml.safe_dump(self.config_dict),
weight_type=self.config.model_config.weight_type)

# create empty weight
self._create_weight(model_comm)
Expand Down Expand Up @@ -574,7 +568,7 @@ def model_inst(self):
return self._model_inst

def _create_model_instance(self, device_id):
model_inst = self.tm_model.model_comm.create_model_instance(device_id)
model_inst = self.tm_model.model_comm.create_request(device_id)
return model_inst

def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,
Expand Down
22 changes: 17 additions & 5 deletions src/turbomind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,26 @@
add_subdirectory(utils)
add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(comm)
add_subdirectory(generation)
add_subdirectory(models)
add_subdirectory(engine)
if(BUILD_PYT)
add_subdirectory(th_op)
endif()

if(BUILD_PY_FFI)
add_subdirectory(python)
endif()
add_subdirectory(triton_backend)

add_library(turbomind STATIC turbomind.cc)
set_property(TARGET turbomind PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(turbomind PUBLIC
engine
models
device_comm
host_comm
core
memory_utils
nvtx_utils
CUDA::cublasLt
CUDA::cudart
yaml-cpp::yaml-cpp)

3 changes: 2 additions & 1 deletion src/turbomind/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ add_library(core STATIC
layout.cc
tensor.cc
tensor.cu
module.cc)
module.cc
copy.cc)

target_link_libraries(core PUBLIC cuda_utils logger CUDA::cudart CUDA::cuda_driver)

Expand Down
6 changes: 5 additions & 1 deletion src/turbomind/core/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,16 @@ void Copy(const Buffer& a, Ref<Buffer> b_)
Copy(a, b_, Context::stream());
}

namespace detail {

void* Copy(const void* a, ssize_t n, void* b, const Stream& stream)
{
check_cuda_error(cudaMemcpyAsync(b, a, n, cudaMemcpyDefault, stream.handle()));
return (char*)b + n;
return (uint8_t*)b + n;
}

} // namespace detail

void Clear(Ref<Buffer> b_, const Stream& stream)
{
auto& b = b_.get();
Expand Down
49 changes: 36 additions & 13 deletions src/turbomind/core/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,20 @@ inline bool operator!=(const Buffer& a, const Buffer& b)
return !(a == b);
}

///////////////////////////////////////////////////////////
// fill

void Fill(Buffer& b, const void* v);

void Fill(Buffer&& b, const void* v);
// Allocates a new uninitialized Buffer with the same size, dtype and device
// as `buffer` (contents are not copied).
inline Buffer empty_like(const Buffer& buffer)
{
    return Buffer{buffer.size(), buffer.dtype(), buffer.device()};
}

void Fill(Buffer& b, const void* v, const Stream& stream);
// Allocates a new uninitialized Buffer with the same size and dtype as
// `buffer`, but placed on `device` (contents are not copied).
inline Buffer empty_like(const Buffer& buffer, Device device)
{
    return Buffer{buffer.size(), buffer.dtype(), device};
}

void Fill(Buffer&& b, const void* v, const Stream& stream);
// Allocates a new uninitialized Buffer with the same size and device as
// `buffer`, but with element type `dtype` (contents are not copied).
inline Buffer empty_like(const Buffer& buffer, DataType dtype)
{
    return Buffer{buffer.size(), dtype, buffer.device()};
}

template<class T>
struct Buffer_: public Buffer {
Expand All @@ -187,10 +191,10 @@ struct Buffer_: public Buffer {

Buffer_(ssize_t size, Device device): Buffer{size, data_type_v<T>, device} {}

Buffer_(const Buffer_&) = default;
Buffer_(const Buffer_&) = default;
Buffer_& operator=(const Buffer_&) = default;

Buffer_(Buffer_&&) noexcept = default;
Buffer_(Buffer_&&) noexcept = default;
Buffer_& operator=(Buffer_&&) noexcept = default;

Buffer_(const Buffer& b)
Expand Down Expand Up @@ -284,7 +288,7 @@ struct Buffer_: public Buffer {
static decltype(auto) ensure_dtype(U&& u) noexcept
{
TM_CHECK_EQ(u.dtype(), data_type_v<T>);
return (U &&) u;
return (U&&)u;
}
};

Expand Down Expand Up @@ -323,24 +327,43 @@ inline void Copy_(const Buffer_<T>& a, ssize_t n, Buffer_<T>& b_)
Copy((const Buffer&)a, n, (Buffer&)b_);
}

namespace detail {

void* Copy(const void* a, ssize_t n, void* b, const Stream& stream);

} // namespace detail

template<class T>
inline T* Copy(const T* a, ssize_t n, T* b, const Stream& stream)
{
return (T*)Copy((const void*)a, sizeof(T) * n, (void*)b, stream);
return (T*)detail::Copy((const void*)a, sizeof(T) * n, (void*)b, stream);
}

template<class T>
inline T* Copy(const T* a, ssize_t n, T* b)
{
return Copy(a, n, b, Context::stream());
return (T*)detail::Copy((const void*)a, sizeof(T) * n, (void*)b, Context::stream());
}

// Callable wrapper over the `Copy` free-function overload set, so that
// "Copy" can be passed where a single callable object is required
// (an overload set by itself cannot be passed as an argument).
struct CopyT {
    template<class... Args>
    auto operator()(Args&&... args) const
    {
        // C-style cast forwards each argument with its original value category
        // (equivalent to std::forward without requiring <utility> here).
        return Copy(((Args&&)args)...);
    }
};

void Clear(Ref<Buffer> b_, const Stream& stream);

void Clear(Ref<Buffer> b_);

template<class T>
std::vector<T> to_vector(const Buffer_<T>& b)
{
TM_CHECK(b.device().type == kCPU || b.device().type == kCPUpinned);
return std::vector<T>(b.begin(), b.end());
}

// clang-format off
template<class Archive>
void save(Archive& ar, const Buffer& buffer)
Expand Down
97 changes: 97 additions & 0 deletions src/turbomind/core/copy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@

#include "src/turbomind/core/copy.h"

#include <cstdint>
#include <type_traits>
#include <variant>

#include <cuda_runtime.h>
#include <driver_types.h>

#include "src/turbomind/core/check.h"

namespace turbomind::core {

// picked from "cudaTypedefs.h"
typedef CUresult(CUDAAPI* PFN_cuMemcpyBatchAsync_v12080)(CUdeviceptr_v2* dsts,
CUdeviceptr_v2* srcs,
size_t* sizes,
size_t count,
CUmemcpyAttributes_v1* attrs,
size_t* attrIdxs,
size_t numAttrs,
size_t* failIdx,
CUstream hStream);

/// TODO: add `PFN_cuMemcpyBatchAsync_v13000`

namespace {

const auto& GetCopyAPI()
{
static auto inst = []() -> std::variant<std::monostate, PFN_cuMemcpyBatchAsync_v12080> {
const auto symbol = "cuMemcpyBatchAsync";
cudaDriverEntryPointQueryResult status{};
void* fpn{};
TM_CHECK_EQ(cudaGetDriverEntryPoint(symbol, &fpn, cudaEnableDefault, &status), 0);
if (fpn && status == cudaDriverEntryPointSuccess) {
return (PFN_cuMemcpyBatchAsync_v12080)fpn;
}
else {
return {};
}
}();
return inst;
}

} // namespace

// Out-of-line defaulted destructor: keeps destruction of members whose types
// may be incomplete in the header anchored to this translation unit.
BatchCopy::~BatchCopy() = default;

// Constructs an empty batch; Reset() clears the pending src/dst/size lists.
// NOTE(review): `self_` is initialized to `this` — presumably a self-reference
// used for lifetime/identity tracking by the declaration in copy.h; confirm.
BatchCopy::BatchCopy(): self_{this}
{
    Reset();
}

void BatchCopy::Run()
{
if (src_.empty()) {
return;
}

std::visit(
[&](auto&& copy) {
using T = std::decay_t<decltype(copy)>;
if constexpr (std::is_same_v<T, PFN_cuMemcpyBatchAsync_v12080>) {
CUmemcpyAttributes_v1 attr{};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
attr.flags = CU_MEMCPY_FLAG_PREFER_OVERLAP_WITH_COMPUTE;
std::vector<size_t> ais(src_.size(), 0);
size_t fail_idx{SIZE_MAX};

auto status = copy((CUdeviceptr_v2*)dst_.data(),
(CUdeviceptr_v2*)src_.data(),
size_.data(),
src_.size(),
&attr,
ais.data(),
1,
&fail_idx,
core::Context::stream().handle());

if (auto i = fail_idx; i != SIZE_MAX) {
TM_CHECK(0) << (void*)src_[i] << " " << size_[i] << " " << (void*)dst_[i] << " code " << status;
}
}
else {
for (unsigned i = 0; i < src_.size(); ++i) {
core::Copy(src_[i], size_[i], dst_[i]);
}
}
},
GetCopyAPI());

Reset();
}

} // namespace turbomind::core
Loading
Loading