diff --git a/.gitmodules b/.gitmodules index a4254bed..ceffdb5e 100755 --- a/.gitmodules +++ b/.gitmodules @@ -27,4 +27,4 @@ url = https://gitcode.com/xLLM-AI/spdlog.git [submodule "third_party/Mooncake"] path = third_party/Mooncake - url = https://gitcode.com/xLLM-AI/Mooncake.git + url = https://gitcode.com/xLLM-AI/Mooncake.git \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 645ce0a2..21001a57 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,10 @@ cmake_minimum_required(VERSION 3.26) set_property(GLOBAL PROPERTY USE_FOLDERS ON) +set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc") option(USE_NPU "Enable NPU support" OFF) option(USE_MLU "Enable MLU support" OFF) +option(USE_CUDA "Enable CUDA support" OFF) if(DEVICE_ARCH STREQUAL "ARM") set(CMAKE_SYSTEM_PROCESSOR aarch64) @@ -101,7 +103,7 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS ON) -if(USE_NPU) +if(USE_NPU OR USE_CUDA) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) elseif(USE_MLU) @@ -178,6 +180,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT}) message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}") endif() +# set architecture for CUDA +if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND USE_CUDA) + set(CMAKE_CUDA_ARCHITECTURES 80) +endif() + +# Build TORCH_CUDA_ARCH_LIST +if(USE_CUDA) + # Build TORCH_CUDA_ARCH_LIST + set(TORCH_CUDA_ARCH_LIST "") + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) + if(CUDA_ARCH MATCHES "^([0-9])([0-9])a$") + set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a") + elseif(CUDA_ARCH MATCHES "^([0-9])([0-9])*$") + set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}") + elseif(CUDA_ARCH STREQUAL "native") + set(TORCH_ARCH "Auto") + else() + message(FATAL_ERROR "${CUDA_ARCH} is not supported") + endif() + list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH}) + endforeach() + + message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}") +endif() + # configure vcpkg # have to set CMAKE_TOOLCHAIN_FILE before first project call. # if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE) @@ -217,7 +245,12 @@ endif() set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation") -project("xllm" LANGUAGES C CXX) +if(USE_CUDA) + project("xllm" LANGUAGES C CXX CUDA) + find_package(CUDAToolkit REQUIRED) +else() + project("xllm" LANGUAGES C CXX) +endif() # find_package(CUDAToolkit REQUIRED) @@ -352,6 +385,43 @@ if(USE_MLU) ) endif() +if(USE_CUDA) + add_definitions(-DUSE_CUDA) + add_compile_definitions(TORCH_CUDA=1) + set(CMAKE_VERBOSE_MAKEFILE ON) + include_directories( + $ENV{PYTHON_INCLUDE_PATH} + $ENV{PYTORCH_INSTALL_PATH}/include + $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include + ) + + link_directories( + $ENV{PYTHON_LIB_PATH} + $ENV{PYTORCH_INSTALL_PATH}/lib + $ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64 + ) + + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3) + # The following definitions must be undefined since half-precision operation is required. 
+ set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} + -U__CUDA_NO_HALF_OPERATORS__ + -U__CUDA_NO_HALF_CONVERSIONS__ + -U__CUDA_NO_HALF2_OPERATORS__ + -U__CUDA_NO_BFLOAT16_CONVERSIONS__) + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all) + message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}") + + # find_package(NCCL REQUIRED) + + # find cudnn + execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)" OUTPUT_VARIABLE CUDNN_PYTHON_PATH) + get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY) + link_directories( + ${CUDNN_ROOT_DIR}/lib64 + ${CUDNN_ROOT_DIR}/lib + ) +endif() + # check if USE_CXX11_ABI is set correctly # if (DEFINED USE_CXX11_ABI) # parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS") diff --git a/setup.py b/setup.py index 5b43e398..fbe40005 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,6 @@ def get_python_include_path(): return None -# PYTORCH_INSTALL_PATH and LIBTORCH_ROOT def get_torch_root_path(): try: import torch @@ -115,6 +114,12 @@ def get_torch_mlu_root_path(): except ImportError: return None +def get_nccl_root_path(): + try: + from nvidia import nccl + return str(Path(nccl.__file__).parent) + except ImportError: + return None def set_npu_envs(): PYTORCH_NPU_INSTALL_PATH = os.getenv("PYTORCH_NPU_INSTALL_PATH") @@ -212,7 +217,16 @@ def set_mlu_envs(): os.environ["LIBTORCH_ROOT"] = get_torch_root_path() os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path() os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path() - + +def set_cuda_envs(): + os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path() + os.environ["PYTHON_LIB_PATH"] = get_torch_root_path() + os.environ["LIBTORCH_ROOT"] = get_torch_root_path() + os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path() + os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda" + os.environ["NCCL_ROOT"] = get_nccl_root_path() + os.environ["NCCL_VERSION"] = "2" + class CMakeExtension(Extension): def __init__(self, name: str, path: str, sourcedir: str = "") -> None: super().__init__(name, sources=[]) @@ -223,7 +237,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None: class ExtBuild(build_ext): user_options = build_ext.user_options + [ ("base-dir=", None, "base directory of xLLM project"), - ("device=", None, "target device type (a3 or a2 or mlu)"), + ("device=", None, "target device type (a3 or a2 or mlu or cuda)"), ("arch=", None, "target arch type (x86 or arm)"), ("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"), ] @@ -302,8 +316,14 @@ def build_extension(self, ext: CMakeExtension): cmake_args += ["-DUSE_MLU=ON"] # set mlu environment variables set_mlu_envs() + elif self.device == "cuda": + cuda_architectures = "80;89;90" + cmake_args += ["-DUSE_CUDA=ON", + f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"] + # set cuda environment variables + set_cuda_envs() else: - raise ValueError("Please set --device to a2 or a3 or mlu.") + raise ValueError("Please set --device to a2 or a3 or mlu or cuda.") # Adding CMake arguments set as environment variable @@ -353,7 +373,7 @@ def build_extension(self, ext: CMakeExtension): class BuildDistWheel(bdist_wheel): user_options = bdist_wheel.user_options + [ - ("device=", None, "target device type (a3 or a2 or mlu)"), + ("device=", None, "target device type (a3 or a2 or mlu or cuda)"), ("arch=", None, "target arch type (x86 or arm)"), ] @@ -530,7 +550,7 @@ def apply_patch(): idx = sys.argv.index('--device') if idx + 1 < len(sys.argv): device = 
sys.argv[idx+1].lower() - if device not in ('a2', 'a3', 'mlu'): + if device not in ('a2', 'a3', 'mlu', 'cuda'): print("Error: --device must be a2 or a3 or mlu (case-insensitive)") sys.exit(1) # Remove the arguments so setup() doesn't see them diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt index 3410b2e5..f1e49f48 100644 --- a/xllm/core/common/CMakeLists.txt +++ b/xllm/core/common/CMakeLists.txt @@ -15,6 +15,7 @@ cc_library( rate_limiter.h types.h device_monitor.h + flashinfer_workspace.h SRCS etcd_client.cpp global_flags.cpp @@ -23,6 +24,7 @@ cc_library( options.cpp rate_limiter.cpp device_monitor.cpp + flashinfer_workspace.cpp DEPS util absl::random_random diff --git a/xllm/core/common/flashinfer_workspace.cpp b/xllm/core/common/flashinfer_workspace.cpp new file mode 100644 index 00000000..e7b72dda --- /dev/null +++ b/xllm/core/common/flashinfer_workspace.cpp @@ -0,0 +1,49 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "flashinfer_workspace.h" + +#include + +#include "global_flags.h" + +namespace xllm { + +void FlashinferWorkspace::initialize(const torch::Device& device) { + LOG(INFO) << "FlashinferWorkspace initialize on device: " << device; + float_workspace_buffer_ = + torch::empty({FLAGS_workspace_buffer_size}, + torch::dtype(torch::kUInt8).device(device)); + int_workspace_buffer_ = torch::empty( + {128 * 1024 * 1024}, torch::dtype(torch::kUInt8).device(device)); + page_locked_int_workspace_buffer_ = torch::empty( + {int_workspace_buffer_.size(0)}, + torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true)); + LOG(INFO) << "FlashinferWorkspace initialize end"; +} + +torch::Tensor FlashinferWorkspace::get_float_workspace_buffer() { + return float_workspace_buffer_; +} + +torch::Tensor FlashinferWorkspace::get_int_workspace_buffer() { + return int_workspace_buffer_; +} + +torch::Tensor FlashinferWorkspace::get_page_locked_int_workspace_buffer() { + return page_locked_int_workspace_buffer_; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/common/flashinfer_workspace.h b/xllm/core/common/flashinfer_workspace.h new file mode 100644 index 00000000..bbd875a3 --- /dev/null +++ b/xllm/core/common/flashinfer_workspace.h @@ -0,0 +1,49 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
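A minimal usage sketch for the new FlashinferWorkspace singleton (illustrative only; init_flashinfer is a hypothetical helper, and the real consumers are the CUDA attention path further down in this patch, which reads the buffers via get_instance()): initialize once per device, then reuse the shared buffers for every FlashInfer plan/run call.

    #include "common/flashinfer_workspace.h"

    // Hypothetical call site, sketched from the class definition above.
    void init_flashinfer(const torch::Device& device) {
      auto& ws = xllm::FlashinferWorkspace::get_instance();
      ws.initialize(device);
      // FLAGS_workspace_buffer_size bytes on device for intermediate attention results
      torch::Tensor float_ws = ws.get_float_workspace_buffer();
      // 128 MiB device buffer plus a pinned host mirror consumed by the plan step
      torch::Tensor int_ws = ws.get_int_workspace_buffer();
      torch::Tensor pinned_ws = ws.get_page_locked_int_workspace_buffer();
    }
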
+==============================================================================*/ + +#pragma once + +#include + +#include + +#include "macros.h" + +namespace xllm { + +class FlashinferWorkspace { + public: + static FlashinferWorkspace& get_instance() { + static FlashinferWorkspace instance; + return instance; + }; + + void initialize(const torch::Device& device); + + torch::Tensor get_float_workspace_buffer(); + torch::Tensor get_int_workspace_buffer(); + torch::Tensor get_page_locked_int_workspace_buffer(); + + private: + FlashinferWorkspace() = default; + ~FlashinferWorkspace() = default; + DISALLOW_COPY_AND_ASSIGN(FlashinferWorkspace); + + torch::Tensor float_workspace_buffer_; + torch::Tensor int_workspace_buffer_; + torch::Tensor page_locked_int_workspace_buffer_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 9e41164e..cc2168ec 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -276,6 +276,7 @@ DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port."); DEFINE_bool(enable_shm, false, "Whether to enable shared memory for executing model."); + // --- function call config --- DEFINE_string(tool_call_parser, @@ -353,6 +354,7 @@ DEFINE_int32(micro_batch_num, "Default use two micro batches for multi-stream parallel."); // --- dit config --- + DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch."); // --- continuous kv cache config --- @@ -377,22 +379,34 @@ DEFINE_int64(buffer_size_per_seq, "Buffer size per sequence in bytes, default 0."); // --- beam search config --- + DEFINE_bool(enable_beam_search_kernel, false, "Whether to enable beam search kernel."); // --- reasoning parser config --- + DEFINE_string(reasoning_parser, "", "Specify the reasoning parser for handling reasoning " "interactions(e.g. glm45, qwen3, deepseek-r1)."); // --- qwen3 reranker config --- + DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker."); +// --- flashinfer config --- + +DEFINE_int32(workspace_buffer_size, + 128 * 1024 * 1024, + "The user reserved workspace buffer used to store intermediate " + "attention results in split-k algorithm for flashinfer."); + +// --- prefetch weight config --- + DEFINE_bool( enable_prefetch_weight, false, "Whether to enable prefetch weight,only applicable to Qwen3-dense model." "The default prefetching ratio for gateup weight is 40%." - "If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5"); + "If adjustments are needed, e.g. 
export PREFETCH_COEFFOCIENT=0.5"); \ No newline at end of file diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 7fc36442..3d17dd54 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -204,3 +204,5 @@ DECLARE_string(reasoning_parser); DECLARE_bool(enable_shm); DECLARE_bool(enable_prefetch_weight); + +DECLARE_int32(workspace_buffer_size); diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp index 22d10f90..ab334461 100644 --- a/xllm/core/distributed_runtime/worker_server.cpp +++ b/xllm/core/distributed_runtime/worker_server.cpp @@ -98,6 +98,7 @@ void WorkerServer::create_server( CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size); const ParallelArgs* parallel_args = comm.parallel_args(); + // TODO: fix bug when creating cuda process group #if defined(USE_MLU) || defined(USE_CUDA) comm.create_process_groups(master_node_addr, device); #endif diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 2ff34176..7f2fe880 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -216,7 +216,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, state_.q_seq_lens.insert(state_.q_seq_lens.end(), state.q_seq_lens.begin(), state.q_seq_lens.end()); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) int32_t seq_len_offset = state_.seq_lens.back(); // skip the first element which is 0 for (size_t i = 1; i < state.seq_lens.size(); ++i) { @@ -288,7 +288,7 @@ void BatchInputBuilder::process_single_sequence( #if defined(USE_NPU) state.seq_lens.push_back(seq_len); state.q_seq_lens.push_back(q_seq_len); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) state.seq_lens.push_back(state.seq_lens.back() + seq_len); state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len); #endif @@ -448,7 +448,12 @@ void BatchInputBuilder::setup_kv_cache_info( block_size = block.size(); block_ids.push_back(block.id()); u_block_ids.emplace_back(block.id()); + state.paged_kv_indices.push_back(block.id()); } + state.paged_kv_indptr.push_back(state.paged_kv_indptr.back() + blocks.size()); + int32_t last_page_len = + (seq_len % block_size == 0) ? 
block_size : seq_len % block_size; + state.paged_kv_last_page_len.push_back(last_page_len); int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size; for (auto iter = block_ids.begin() + kv_cache_block_idx; @@ -517,12 +522,15 @@ void BatchInputBuilder::padding_decode_batch_size( #if defined(USE_NPU) state_.seq_lens.push_back(num_decoding_tokens); state_.q_seq_lens.push_back(num_decoding_tokens); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens); state_.q_seq_lens.push_back(state_.q_seq_lens.back() + num_decoding_tokens); #endif state_.block_tables_vec.emplace_back(); + state_.paged_kv_indices.push_back(0); + state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1); + state_.paged_kv_last_page_len.push_back(1); } } } @@ -560,6 +568,14 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { input_params.decode_seq_range = util::find_ones_indices(input_params.q_seq_lens_vec); + // for flashinfer + input_params.paged_kv_indptr = + torch::tensor(state_.paged_kv_indptr, torch::kInt); + input_params.paged_kv_indices = + torch::tensor(state_.paged_kv_indices, torch::kInt); + input_params.paged_kv_last_page_len = + torch::tensor(state_.paged_kv_last_page_len, torch::kInt); + // Setup multimodal data input_params.mm_data = MMData::batch(mm_data_vec_); @@ -634,6 +650,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos); raw_forward_input.prefill_seq_len = state_.prefill_seq_len; + // for flashinfer + raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr); + raw_forward_input.paged_kv_indices = std::move(state_.paged_kv_indices); + raw_forward_input.paged_kv_last_page_len = + std::move(state_.paged_kv_last_page_len); + raw_forward_input.embedding_ids = std::move(state_.embedding_ids); raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids); // beam search kernel input diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index 9b76bfb1..508610fd 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -86,7 +86,7 @@ class BatchInputBuilder { #if defined(USE_NPU) std::vector seq_lens; std::vector q_seq_lens; -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) std::vector seq_lens = {0}; // cu_seq_lens std::vector q_seq_lens = {0}; // q_cu_seq_len #endif @@ -107,6 +107,11 @@ class BatchInputBuilder { // for continuous kvcache std::vector new_cache_slot_offsets; //[n_tokens] std::vector kv_cache_start_offsets; //[n_seq] + + // for flashinfer + std::vector paged_kv_indptr = {0}; + std::vector paged_kv_indices; + std::vector paged_kv_last_page_len; }; // Helper methods for sequence processing diff --git a/xllm/core/framework/batch/batch_test.cpp b/xllm/core/framework/batch/batch_test.cpp index b79f7b6d..2645fe56 100644 --- a/xllm/core/framework/batch/batch_test.cpp +++ b/xllm/core/framework/batch/batch_test.cpp @@ -152,7 +152,7 @@ TEST(BatchTest, Basic) { #if defined(USE_NPU) const std::vector q_seq_lens = {9, 1, 1, 4}; -#elif defined(USE_MLU) +#else const std::vector q_seq_lens = {0, 9, 10, 11, 15}; #endif EXPECT_TRUE(equal(input_params.q_seq_lens, q_seq_lens)); @@ -160,7 +160,7 @@ TEST(BatchTest, Basic) { // seq4's kv_seq_len = q_len + num_cached_tokens (q_len<=max_allowed_tokens) #if defined(USE_NPU) const std::vector kv_seq_lens = {9, 8, 16, 
8}; -#elif defined(USE_MLU) +#else const std::vector kv_seq_lens = {0, 9, 17, 33, 41}; #endif EXPECT_TRUE(equal(input_params.kv_seq_lens, kv_seq_lens)); diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 6669baaa..19a231da 100755 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -97,6 +97,11 @@ struct ModelInputParams { // Copy graph_buffer to device params.graph_buffer = safe_to(graph_buffer, device, true); + // params for flashinfer + params.paged_kv_indptr = safe_to(paged_kv_indptr, device); + params.paged_kv_indices = safe_to(paged_kv_indices, device); + params.paged_kv_last_page_len = safe_to(paged_kv_last_page_len, device); + return params; } @@ -201,6 +206,21 @@ struct ModelInputParams { // Graph execution buffer for temporary tensor storage // Used by ACL Graph Executor to avoid repeated memory allocation torch::Tensor graph_buffer; + + // the indptr of the paged kv-cache + // used in flashinfer + // IntTensor: [n_seq + 1] + torch::Tensor paged_kv_indptr; + + // the page indices of the paged kv cache + // used in flashinfer + torch::Tensor paged_kv_indices; + + // the number of entries in the last page of each request in + // the paged kv cache + // used in flashinfer + // IntTensor: [n_seq] + torch::Tensor paged_kv_last_page_len; }; } // namespace xllm diff --git a/xllm/core/framework/parallel_state/cuda_process_group.h b/xllm/core/framework/parallel_state/cuda_process_group.h index 3e1ed375..349cf008 100644 --- a/xllm/core/framework/parallel_state/cuda_process_group.h +++ b/xllm/core/framework/parallel_state/cuda_process_group.h @@ -34,7 +34,9 @@ class ProcessGroupNccl : public ProcessGroup { : ProcessGroup(device) { c10::intrusive_ptr pg_options = c10d::ProcessGroupNCCL::Options::create(); +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7 pg_options->group_name = group_name; +#endif int rank = global_rank; if (world_size != rank_size) { auto [local_rank, group_ranks] = @@ -47,12 +49,6 @@ class ProcessGroupNccl : public ProcessGroup { pg_ = std::make_unique( store, rank, rank_size, pg_options); } - - ~ProcessGroupNccl() override { - if (pg_) { - pg_->shutdown(); - } - } }; std::unique_ptr create_process_group( diff --git a/xllm/core/framework/parallel_state/mlu_process_group.h b/xllm/core/framework/parallel_state/mlu_process_group.h index 3a0f1138..09a95f55 100644 --- a/xllm/core/framework/parallel_state/mlu_process_group.h +++ b/xllm/core/framework/parallel_state/mlu_process_group.h @@ -47,12 +47,6 @@ class ProcessGroupCncl : public ProcessGroup { pg_ = std::make_unique( store, rank, rank_size, pg_options); } - - ~ProcessGroupCncl() override { - if (pg_) { - pg_->shutdown(); - } - } }; std::unique_ptr create_process_group( diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt index 4aa1941b..3bba0e16 100644 --- a/xllm/core/kernels/CMakeLists.txt +++ b/xllm/core/kernels/CMakeLists.txt @@ -8,6 +8,9 @@ if(USE_MLU) add_subdirectory(mlu) endif() +if(USE_CUDA) + add_subdirectory(cuda) +endif() cc_library( NAME @@ -21,4 +24,5 @@ cc_library( torch $<$:npu_kernels> $<$:mlu_kernels> + $<$:cuda_kernels> ) \ No newline at end of file diff --git a/xllm/core/kernels/cuda/CMakeLists.txt b/xllm/core/kernels/cuda/CMakeLists.txt new file mode 100644 index 00000000..8c5e7c6f --- /dev/null +++ b/xllm/core/kernels/cuda/CMakeLists.txt @@ -0,0 +1,22 @@ +include(cc_library) + +file(GLOB_RECURSE CUDA_HEADER_FILES + "${CMAKE_CURRENT_LIST_DIR}/*.h" 
+) + +file(GLOB_RECURSE CUDA_SOURCE_FILES + "${CMAKE_CURRENT_LIST_DIR}/*.cpp" + "${CMAKE_CURRENT_LIST_DIR}/*.cu" +) + +cc_library( + NAME + cuda_kernels + HDRS + ${CUDA_HEADER_FILES} + SRCS + ${CUDA_SOURCE_FILES} + DEPS + tvm_ffi + torch +) diff --git a/xllm/core/kernels/cuda/activation.cpp b/xllm/core/kernels/cuda/activation.cpp new file mode 100644 index 00000000..d949d2fe --- /dev/null +++ b/xllm/core/kernels/cuda/activation.cpp @@ -0,0 +1,38 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void act_and_mul(torch::Tensor out, + torch::Tensor input, + const std::string& act_mode) { + if (act_mode != "silu" && act_mode != "gelu" && act_mode != "gelu_tanh") { + throw std::runtime_error("Unsupported act mode: " + act_mode + + ", only support silu, gelu, gelu_tanh"); + } + + std::string uri = act_mode + "_and_mul"; + + auto lib = torch::DynamicLibrary(path_to_uri(uri).c_str(), nullptr, true); + std::string schema_name = uri; + + torch::Dispatcher::singleton() + .findSchemaOrThrow(schema_name.c_str(), "") + .typed() + .call(out, input, support_pdl()); +} +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/batch_decode.cpp b/xllm/core/kernels/cuda/batch_decode.cpp new file mode 100644 index 00000000..f93328d0 --- /dev/null +++ b/xllm/core/kernels/cuda/batch_decode.cpp @@ -0,0 +1,132 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void batch_decode(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor query, + torch::Tensor k_cache, + torch::Tensor v_cache, + torch::Tensor q_cu_seq_lens, + torch::Tensor paged_kv_indptr, + torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, + int64_t window_size_left, + torch::Tensor output, + std::optional& output_lse, + bool enable_cuda_graph) { + std::string uri = get_batch_decode_uri(query.scalar_type(), + k_cache.scalar_type(), + output.scalar_type(), + paged_kv_indptr.scalar_type(), + query.size(-1), + v_cache.size(-1), + /*pos_encoding_mode=*/0, + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false); + + torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU); + const int64_t batch_size = q_cu_seq_lens.size(0) - 1; + const double sm_scale = compute_sm_scale(query.size(-1)); + + torch::Tensor empty_q_data = + torch::empty({0}, torch::TensorOptions().dtype(query.scalar_type())); + torch::Tensor empty_kv_data = + torch::empty({0}, torch::TensorOptions().dtype(k_cache.scalar_type())); + + auto lib = torch::DynamicLibrary(path_to_uri(uri).c_str(), nullptr, true); + std::string plan_schema_name = uri + "::plan"; + std::string run_schema_name = uri + "::run"; + + torch::Tensor plan_info = torch::Dispatcher::singleton() + .findSchemaOrThrow(plan_schema_name.c_str(), "") + .typed() + .call(float_workspace_buffer, + int_workspace_buffer, + page_locked_int_workspace_buffer, + qo_indptr_host, + batch_size, + query.size(1), // num_qo_heads + k_cache.size(2), // num_kv_heads + k_cache.size(1), // block_size + enable_cuda_graph, + /*window_left=*/-1, + /* logits_soft_cap=*/0.0, + query.size(-1), // head_dim_qk + v_cache.size(-1), // head_dim_vo + empty_q_data, + empty_kv_data); + + torch::Dispatcher::singleton() + .findSchemaOrThrow(run_schema_name.c_str(), "") + .typed, + int64_t, + int64_t, + bool, + std::optional, + double, + double, + double, + double)>() + .call(float_workspace_buffer, + int_workspace_buffer, + plan_info, + query, + k_cache, + v_cache, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + output, + output_lse, + /*kv_layout_code=*/0, // NHD layout + /*window_left=*/-1, + support_pdl(), + /*maybe_alibi_slopes=*/std::optional(), + /*logits_soft_cap=*/0.0, + /*sm_scale=*/sm_scale, + /*rope_rcp_scale=*/1.0, + /*rope_rcp_theta=*/1.0 / 10000.0); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/batch_prefill.cpp b/xllm/core/kernels/cuda/batch_prefill.cpp new file mode 100644 index 00000000..c836bd2c --- /dev/null +++ b/xllm/core/kernels/cuda/batch_prefill.cpp @@ -0,0 +1,143 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void batch_prefill(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor q_cu_seq_lens, + torch::Tensor kv_cu_seq_lens, + int64_t window_left, + torch::Tensor output, + std::optional& output_lse, + bool enable_cuda_graph) { + std::string uri = get_batch_prefill_uri(/*backend=*/"fa2", + query.scalar_type(), + key.scalar_type(), + output.scalar_type(), + q_cu_seq_lens.scalar_type(), + query.size(-1), + value.size(-1), + /*pos_encoding_mode=*/0, + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, + /*use_fp16_qk_reduction=*/false); + + torch::Tensor kv_indptr_host = kv_cu_seq_lens.to(torch::kCPU); + torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU); + torch::Tensor kv_len_arr_host = + kv_indptr_host.slice(0, 1) - kv_indptr_host.slice(0, 0, -1); + const int64_t total_num_rows = qo_indptr_host.size(0); + const int64_t batch_size = q_cu_seq_lens.size(0) - 1; + const double sm_scale = compute_sm_scale(query.size(-1)); + + auto lib = torch::DynamicLibrary(path_to_uri(uri).c_str(), nullptr, true); + std::string plan_schema_name = uri + "::plan"; + std::string run_schema_name = uri + "::ragged_run"; + + auto plan_info = torch::Dispatcher::singleton() + .findSchemaOrThrow(plan_schema_name.c_str(), "") + .typed() + .call(float_workspace_buffer, + int_workspace_buffer, + page_locked_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + kv_len_arr_host, + total_num_rows, + batch_size, + query.size(1), // num_qo_heads + key.size(1), // num_kv_heads + /*page_size=*/1, + enable_cuda_graph, + query.size(-1), // head_dim_qk + value.size(-1), // head_dim_vo + /*causal=*/true); + + torch::Dispatcher::singleton() + .findSchemaOrThrow(run_schema_name.c_str(), "") + .typed, + int64_t, + int64_t, + int64_t, + bool, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + double, + double, + double, + double, + int64_t)>() + .call(float_workspace_buffer, + int_workspace_buffer, + plan_info, + query, + key, + value, + q_cu_seq_lens, + kv_cu_seq_lens, + output, + output_lse, + /*mask_mode_code=CAUSAL*/ 1, + /*kv_layout_code=*/0, // NHD layout + /*window_left=*/-1, + support_pdl(), + /*maybe_custom_mask=*/std::optional(), + /*maybe_mask_indptr=*/std::optional(), + /*maybe_alibi_slopes=*/std::optional(), + /*maybe_prefix_len_ptr=*/std::optional(), + /*maybe_token_pos_in_items_ptr=*/std::optional(), + /*maybe_max_item_len_ptr=*/std::optional(), + /*logits_soft_cap=*/0.0, + /*sm_scale=*/sm_scale, + /*rope_rcp_scale=*/1.0, + /*rope_rcp_theta=*/1.0 / 10000.0, + /*token_pos_in_items_len=*/0); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/cuda_ops_api.h b/xllm/core/kernels/cuda/cuda_ops_api.h new file mode 100644 index 00000000..c54b3536 --- /dev/null +++ b/xllm/core/kernels/cuda/cuda_ops_api.h @@ -0,0 +1,82 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include + +#include "utils.h" + +namespace xllm::kernel::cuda { + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, + torch::Tensor k, + torch::Tensor cos_sin_cache, + torch::Tensor pos_ids, + bool interleave); + +// act_mode only support silu, gelu, gelu_tanh +void act_and_mul(torch::Tensor out, + torch::Tensor input, + const std::string& act_mode); + +void reshape_paged_cache( + torch::Tensor slot_ids, // [n_tokens] + torch::Tensor keys, // [n_tokens, n_kv_heads, head_dim] + torch::Tensor values, // [n_tokens, n_kv_heads, head_dim] + torch::Tensor key_cache, // [n_blocks, block_size, n_heads, head_dim] + torch::Tensor value_cache); + +void batch_prefill(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor q_cu_seq_lens, + torch::Tensor kv_cu_seq_lens, + int64_t window_left, + torch::Tensor output, + std::optional& output_lse, + bool enable_cuda_graph); + +void batch_decode(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, + torch::Tensor page_locked_int_workspace_buffer, + torch::Tensor query, + torch::Tensor k_cache, + torch::Tensor v_cache, + torch::Tensor q_cu_seq_lens, + torch::Tensor paged_kv_indptr, + torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, + int64_t window_left, + torch::Tensor output, + std::optional& output_lse, + bool enable_cuda_graph); + +void rmsnorm(torch::Tensor output, + torch::Tensor input, + torch::Tensor weight, + double eps); + +torch::Tensor matmul(torch::Tensor a, + torch::Tensor b, + std::optional bias); + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/matmul.cpp b/xllm/core/kernels/cuda/matmul.cpp new file mode 100644 index 00000000..2af7a61a --- /dev/null +++ b/xllm/core/kernels/cuda/matmul.cpp @@ -0,0 +1,27 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +torch::Tensor matmul(torch::Tensor a, + torch::Tensor b, + std::optional bias) { + namespace F = torch::nn::functional; + return F::linear(a, b, bias.value_or(torch::Tensor())); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/norm.cpp b/xllm/core/kernels/cuda/norm.cpp new file mode 100644 index 00000000..514cae30 --- /dev/null +++ b/xllm/core/kernels/cuda/norm.cpp @@ -0,0 +1,33 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void rmsnorm(torch::Tensor output, + torch::Tensor input, + torch::Tensor weight, + double eps) { + auto lib = torch::DynamicLibrary(path_to_uri("norm").c_str(), nullptr, true); + std::string schema_name = "norm::rmsnorm"; + + torch::Dispatcher::singleton() + .findSchemaOrThrow(schema_name.c_str(), "") + .typed() + .call(output, input, weight, eps, support_pdl()); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/reshape_paged_cache.cu b/xllm/core/kernels/cuda/reshape_paged_cache.cu new file mode 100644 index 00000000..c27eddf6 --- /dev/null +++ b/xllm/core/kernels/cuda/reshape_paged_cache.cu @@ -0,0 +1,108 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "cuda_ops_api.h" + +namespace { +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) +#define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +// NOLINTEND(cppcoreguidelines-macro-usage) +} // namespace + +namespace xllm::kernel::cuda { + +template +__global__ void reshape_paged_cache_kernel( + const int* __restrict__ slot_ids, // [n_tokens] + const T* __restrict__ keys, // [n_tokens, n_heads, head_dim] + const T* __restrict__ values, // [n_tokens, n_heads, head_dim] + T* __restrict__ key_cache, + T* __restrict__ value_cache, + int64_t k_stride, + int64_t v_stride, + int64_t n_kv_heads, + int64_t head_dim, + int64_t block_size) { + // block/token index + const int64_t bid = blockIdx.x; + // which slot to write to + const int64_t slot_id = slot_ids[bid]; + // block index + const int64_t block_idx = slot_id / block_size; + // offset within block + const int64_t block_offset = slot_id % block_size; + // base index for the block in cache + const int64_t block_base_idx = block_idx * block_size * n_kv_heads * head_dim; + // copy value one by one for the token + for (int64_t i = threadIdx.x; i < n_kv_heads * head_dim; i += blockDim.x) { + const int64_t k_src_idx = bid * k_stride + i; + const int64_t v_src_idx = bid * v_stride + i; + // cache: [n_blocks, block_size, n_heads, head_dim] + const int64_t head_base_idx = + block_base_idx + block_offset * n_kv_heads * head_dim; + // which head to write to + const int head_idx = i / head_dim; + // which dim within head to write to + const int head_offset = i % head_dim; + const int64_t dst_idx = head_base_idx + head_idx * head_dim + head_offset; + key_cache[dst_idx] = keys[k_src_idx]; + value_cache[dst_idx] = values[v_src_idx]; + } +} + +void reshape_paged_cache( + torch::Tensor slot_ids, // [n_tokens] + torch::Tensor keys, // [n_tokens, n_kv_heads, head_dim] + torch::Tensor values, // [n_tokens, n_kv_heads, head_dim] + torch::Tensor key_cache, // [n_blocks, block_size, n_heads, head_dim] + torch::Tensor value_cache) { + // keys and values should be continuous at n_kv_heads and head_dim dims + CHECK(keys.stride(-1) == 1 && keys.stride(-2) == keys.size(-1)); + CHECK(values.stride(-1) == 1 && values.stride(-2) == values.size(-1)); + const int64_t n_tokens = keys.size(-3); + const int64_t n_kv_heads = keys.size(-2); + const int64_t head_dim = keys.size(-1); + const int64_t block_size = key_cache.size(-3); + // it is possible that keys and values have different strides + const int64_t k_stride = keys.stride(-3); + const int64_t v_stride = values.stride(-3); + const int64_t n = n_kv_heads * head_dim; + dim3 grid(n_tokens); + dim3 block(std::min(n, 1024)); + DISPATCH_FLOATING_TYPES( + keys.scalar_type(), "reshape_paged_cache_kernel", [&] { + reshape_paged_cache_kernel + <<>>( + slot_ids.data_ptr(), + keys.data_ptr(), + values.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + k_stride, + v_stride, + n_kv_heads, + head_dim, + block_size); + }); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/rope.cpp b/xllm/core/kernels/cuda/rope.cpp new file mode 100644 index 00000000..be2e9744 --- /dev/null +++ b/xllm/core/kernels/cuda/rope.cpp @@ -0,0 +1,44 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
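To make the addressing in reshape_paged_cache_kernel concrete, a worked example with assumed shapes (block_size = 16, n_kv_heads = 8, head_dim = 128) and an assumed slot:

    // slot_id = 37     -> block_idx = 37 / 16 = 2, block_offset = 37 % 16 = 5
    // element i = 1000 -> head_idx = 1000 / 128 = 7, head_offset = 1000 % 128 = 104
    // dst_idx = 2*16*8*128 + 5*8*128 + 7*128 + 104
    //         = 32768 + 5120 + 896 + 104 = 38888
    // i.e. token slot 37 is written into block 2, row 5 of the
    // [n_blocks, block_size, n_kv_heads, head_dim] cache layout.
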
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, + torch::Tensor k, + torch::Tensor cos_sin_cache, + torch::Tensor pos_ids, + bool interleave) { + const int64_t head_dim = cos_sin_cache.size(-1); + q = q.view({q.size(0), -1, head_dim}); + k = k.view({k.size(0), -1, head_dim}); + + auto lib = torch::DynamicLibrary(path_to_uri("rope").c_str(), nullptr, true); + std::string schema_name = "rope::apply_rope_pos_ids_cos_sin_cache"; + + torch::Dispatcher::singleton() + .findSchemaOrThrow(schema_name.c_str(), "") + .typed() + .call(q, k, q, k, cos_sin_cache, pos_ids, interleave); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/utils.cpp b/xllm/core/kernels/cuda/utils.cpp new file mode 100644 index 00000000..b73ed98b --- /dev/null +++ b/xllm/core/kernels/cuda/utils.cpp @@ -0,0 +1,114 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "utils.h" + +#include + +#include + +namespace { +const std::string base_ops_path = + "/root/.cache/flashinfer/0.5.0/80_89_90a/cached_ops"; + +const std::unordered_map + filename_safe_dtype_map = { + {torch::kFloat16, "f16"}, + {torch::kBFloat16, "bf16"}, + {torch::kFloat8_e4m3fn, "e4m3"}, + {torch::kFloat8_e5m2, "e5m2"}, + {torch::kInt8, "i8"}, + {torch::kUInt8, "u8"}, + {torch::kInt32, "i32"}, + {torch::kUInt32, "u32"}, + {torch::kInt64, "i64"}, + {torch::kUInt64, "u64"}, +}; +} // namespace + +namespace xllm::kernel::cuda { + +// Whether to enable Programmatic Dependent Launch (PDL). See +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization +// Only supported for >= sm90, and currently only for FA2, CUDA core, and +// trtllm-gen decode. 
+bool support_pdl() { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, /*device_id=*/0); + return prop.major >= 9; +} + +double compute_sm_scale(int64_t head_dim) { + return 1.0 / std::sqrt(static_cast(head_dim)); +} + +std::string path_to_uri(const std::string& uri) { + return base_ops_path + "/" + uri + "/" + uri + ".so"; +} + +std::string get_batch_prefill_uri(const std::string& backend, + torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap, + bool use_fp16_qk_reduction) { + std::ostringstream oss; + oss << "batch_prefill_with_kv_cache_" + << "dtype_q_" << filename_safe_dtype_map.at(dtype_q) << "_" + << "dtype_kv_" << filename_safe_dtype_map.at(dtype_kv) << "_" + << "dtype_o_" << filename_safe_dtype_map.at(dtype_o) << "_" + << "dtype_idx_" << filename_safe_dtype_map.at(dtype_idx) << "_" + << "head_dim_qk_" << head_dim_qk << "_" + << "head_dim_vo_" << head_dim_vo << "_" + << "posenc_" << pos_encoding_mode << "_" + << "use_swa_" << (use_sliding_window ? "True" : "False") << "_" + << "use_logits_cap_" << (use_logits_soft_cap ? "True" : "False") << "_" + << "f16qk_" << (use_fp16_qk_reduction ? "True" : "False"); + + if (backend == "fa3") oss << "_sm90"; + + return oss.str(); +} + +std::string get_batch_decode_uri(torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap) { + std::ostringstream oss; + oss << "batch_decode_with_kv_cache_" + << "dtype_q_" << filename_safe_dtype_map.at(dtype_q) << "_" + << "dtype_kv_" << filename_safe_dtype_map.at(dtype_kv) << "_" + << "dtype_o_" << filename_safe_dtype_map.at(dtype_o) << "_" + << "dtype_idx_" << filename_safe_dtype_map.at(dtype_idx) << "_" + << "head_dim_qk_" << head_dim_qk << "_" + << "head_dim_vo_" << head_dim_vo << "_" + << "posenc_" << pos_encoding_mode << "_" + << "use_swa_" << (use_sliding_window ? "True" : "False") << "_" + << "use_logits_cap_" << (use_logits_soft_cap ? "True" : "False"); + + return oss.str(); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/utils.h b/xllm/core/kernels/cuda/utils.h new file mode 100644 index 00000000..55482cdf --- /dev/null +++ b/xllm/core/kernels/cuda/utils.h @@ -0,0 +1,52 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
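As an illustration of the URI scheme produced by the helpers above, with assumed values (bf16 query/KV/output, int32 indices, 128-dim heads, no sliding window or logits soft cap), the decode kernel would be looked up as:

    std::string uri = get_batch_decode_uri(torch::kBFloat16, torch::kBFloat16,
                                           torch::kBFloat16, torch::kInt32,
                                           /*head_dim_qk=*/128, /*head_dim_vo=*/128,
                                           /*pos_encoding_mode=*/0,
                                           /*use_sliding_window=*/false,
                                           /*use_logits_soft_cap=*/false);
    // uri == "batch_decode_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_"
    //        "dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False"
    // path_to_uri(uri) == "/root/.cache/flashinfer/0.5.0/80_89_90a/cached_ops/" + uri + "/" + uri + ".so"
    // which assumes the op has already been built into FlashInfer's JIT cache at that fixed path.
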
+==============================================================================*/ + +#pragma once + +#include + +#include + +namespace xllm::kernel::cuda { + +bool support_pdl(); + +double compute_sm_scale(int64_t head_dim); + +std::string path_to_uri(const std::string& uri); + +std::string get_batch_prefill_uri(const std::string& backend, + torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap, + bool use_fp16_qk_reduction); + +std::string get_batch_decode_uri(torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap); + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/mlu/mlu_ops_api.h b/xllm/core/kernels/mlu/mlu_ops_api.h index 14c18783..cbf84b19 100644 --- a/xllm/core/kernels/mlu/mlu_ops_api.h +++ b/xllm/core/kernels/mlu/mlu_ops_api.h @@ -26,11 +26,6 @@ limitations under the License. namespace xllm::kernel::mlu { -static const std::string kActModeSilu = "silu"; -static const std::string kActModeGelu = "gelu"; -static const std::string kActModeQuickGelu = "quick_gelu"; -static const std::string kActModeSwish = "swish"; - void apply_rotary(torch::Tensor& q, torch::Tensor& k, const torch::Tensor& sin, diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 01db841c..8690be19 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -15,8 +15,15 @@ limitations under the License. #include "ops_api.h" -namespace xllm { -namespace kernel { +#if defined(USE_MLU) +#include "mlu/mlu_ops_api.h" +#elif defined(USE_CUDA) +#include "cuda/cuda_ops_api.h" +#endif + +#include + +namespace xllm::kernel { void apply_rotary(RotaryParams& params) { #if defined(USE_MLU) @@ -30,6 +37,25 @@ void apply_rotary(RotaryParams& params) { params.discrete, params.dynamic_ntk, params.max_query_len); +#elif defined(USE_CUDA) + if (!params.position_ids.has_value()) { + std::vector position_ids_vec; + position_ids_vec.reserve(params.cu_query_lens[-1].item()); + for (size_t i = 0; i < params.cu_query_lens.size(0); ++i) { + int32_t start_pos = params.cu_query_lens[i].item(); + int32_t end_pos = params.cu_query_lens[i + 1].item(); + std::iota(position_ids_vec.begin() + start_pos, + position_ids_vec.begin() + end_pos, + start_pos); + } + params.position_ids = torch::from_blob( + position_ids_vec.data(), position_ids_vec.size(), torch::kInt32); + } + cuda::apply_rope_pos_ids_cos_sin_cache(params.q, + params.k, + params.cos_sin, + params.position_ids.value(), + params.interleaved); #else throw std::runtime_error("apply_rotary not implemented"); #endif @@ -45,6 +71,8 @@ void active(ActivationParams& params) { params.is_gated, params.start_expert_id, params.expert_size); +#elif defined(USE_CUDA) + cuda::act_and_mul(params.output, params.input, params.act_mode); #else throw std::runtime_error("active not implemented"); #endif @@ -58,6 +86,12 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) { params.v_cache, params.slot_mapping, params.direction); +#elif defined(USE_CUDA) + cuda::reshape_paged_cache(params.slot_mapping, + params.key, + params.value.value_or(torch::Tensor()), + params.k_cache, + params.v_cache.value_or(torch::Tensor())); #else throw 
std::runtime_error("reshape_paged_cache not implemented"); #endif @@ -87,6 +121,19 @@ void batch_prefill(AttentionParams& params) { params.window_size_right, params.compute_dtype, params.return_lse); +#elif defined(USE_CUDA) + cuda::batch_prefill(params.float_workspace_buffer, + params.int_workspace_buffer, + params.page_locked_int_workspace_buffer, + params.query, + params.key, + params.value, + params.q_cu_seq_lens, + params.kv_cu_seq_lens, + params.window_size_left, + params.output, + params.output_lse, + params.enable_cuda_graph); #else throw std::runtime_error("batch_prefill not implemented"); #endif @@ -114,6 +161,23 @@ void batch_decode(AttentionParams& params) { params.scale, params.return_lse, params.kv_cache_quant_bit_size); +#elif defined(USE_CUDA) + params.query = params.query.squeeze(1); + params.output = params.output.squeeze(1); + cuda::batch_decode(params.float_workspace_buffer, + params.int_workspace_buffer, + params.page_locked_int_workspace_buffer, + params.query, + params.k_cache, + params.v_cache, + params.q_cu_seq_lens, + params.paged_kv_indptr, + params.paged_kv_indices, + params.paged_kv_last_page_len, + params.window_size_left, + params.output, + params.output_lse, + params.enable_cuda_graph); #else throw std::runtime_error("batch_decode not implemented"); #endif @@ -136,6 +200,8 @@ void fused_layernorm(FusedLayerNormParams& params) { params.store_output_before_norm, params.store_output_after_norm, params.dynamic_quant); +#elif defined(USE_CUDA) + cuda::rmsnorm(params.output, params.input, params.weight, params.eps); #else throw std::runtime_error("fused_layernorm not implemented"); #endif @@ -145,6 +211,8 @@ torch::Tensor matmul(MatmulParams& params) { #if defined(USE_MLU) return mlu::matmul( params.a, params.b, params.bias, params.c, params.alpha, params.beta); +#elif defined(USE_CUDA) + return cuda::matmul(params.a, params.b, params.bias); #else throw std::runtime_error("matmul not implemented"); #endif @@ -182,6 +250,8 @@ torch::Tensor fused_moe(FusedMoEParams& params) { params.world_size, params.shared_expert_num, params.parallel_mode); +#elif defined(USE_CUDA) + throw std::runtime_error("fused_moe for cudanot implemented"); #else throw std::runtime_error("fused_moe not implemented"); #endif @@ -271,5 +341,4 @@ void masked_indexer_select_paged_kv(MaskedIndexerSelectPagedKVParams& params) { #endif } -} // namespace kernel -} // namespace xllm +} // namespace xllm::kernel diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index c5c64948..4250b70a 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -17,12 +17,12 @@ limitations under the License. #include "param.h" -#if defined(USE_MLU) -#include "mlu/mlu_ops_api.h" -#endif +namespace xllm::kernel { -namespace xllm { -namespace kernel { +static const std::string kActModeSilu = "silu"; +static const std::string kActModeGelu = "gelu"; +static const std::string kActModeQuickGelu = "quick_gelu"; +static const std::string kActModeSwish = "swish"; void apply_rotary(RotaryParams& params); @@ -51,5 +51,4 @@ torch::Tensor random_sample(RandomSampleParams& params); void masked_indexer_select_paged_kv(MaskedIndexerSelectPagedKVParams& params); -} // namespace kernel -} // namespace xllm +} // namespace xllm::kernel diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 0c433847..8c604468 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -21,8 +21,7 @@ limitations under the License. 
#include #include -namespace xllm { -namespace kernel { +namespace xllm::kernel { // Note: add default values for optional parameters in the struct definition @@ -242,5 +241,4 @@ struct MaskedIndexerSelectPagedKVParams { int64_t quant_block_size; }; -} // namespace kernel -} // namespace xllm +} // namespace xllm::kernel diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 53b987ac..5863c9e6 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -62,7 +62,6 @@ cc_library( word_embedding.h lm_head.h block_copy.h - linear.h SRCS multi_head_attention.cpp DEPS diff --git a/xllm/core/layers/common/CMakeLists.txt b/xllm/core/layers/common/CMakeLists.txt index 32e680ce..26c03aba 100755 --- a/xllm/core/layers/common/CMakeLists.txt +++ b/xllm/core/layers/common/CMakeLists.txt @@ -15,6 +15,7 @@ cc_library( qwen3_decoder_layer.h qwen3_moe_decoder_layer.h linear_impl.h + linear.h word_embedding_impl.h layer_utils.h indexer.h @@ -43,6 +44,7 @@ cc_library( glog::glog gflags::gflags torch + :platform ) # Add test for DenseMLP diff --git a/xllm/core/layers/common/attention.cpp b/xllm/core/layers/common/attention.cpp index adb8b911..8caa5c82 100644 --- a/xllm/core/layers/common/attention.cpp +++ b/xllm/core/layers/common/attention.cpp @@ -15,6 +15,7 @@ limitations under the License. #include "attention.h" +#include "common/flashinfer_workspace.h" #include "kernels/ops_api.h" DECLARE_bool(enable_chunked_prefill); @@ -37,6 +38,13 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params, attn_metadata.slot_mapping = params.new_cache_slots; attn_metadata.compute_dtype = compute_dtype; + // for flashinfer + attn_metadata.paged_kv_indptr = params.paged_kv_indptr; + attn_metadata.paged_kv_indices = params.paged_kv_indices; + attn_metadata.paged_kv_last_page_len = params.paged_kv_last_page_len; + attn_metadata.q_cu_seq_lens = params.q_seq_lens; + attn_metadata.kv_cu_seq_lens = params.kv_seq_lens; // cumulative kv seqlens + bool is_start_loc_match = (params.q_seq_lens_vec == params.kv_seq_lens_vec); attn_metadata.is_chunked_prefill = is_prefill && !is_start_loc_match; attn_metadata.is_prefill = is_prefill && !attn_metadata.is_chunked_prefill; @@ -96,6 +104,16 @@ std::tuple> AttentionImpl::forward( attention_params.window_size_left = sliding_window_; attention_params.scale = scale_; attention_params.compute_dtype = attn_metadata.compute_dtype; + // for flashinfer + attention_params.float_workspace_buffer = + FlashinferWorkspace::get_instance().get_float_workspace_buffer(); + attention_params.int_workspace_buffer = + FlashinferWorkspace::get_instance().get_int_workspace_buffer(); + attention_params.page_locked_int_workspace_buffer = + FlashinferWorkspace::get_instance() + .get_page_locked_int_workspace_buffer(); + attention_params.kv_cu_seq_lens = attn_metadata.kv_cu_seq_lens; + attention_params.q_cu_seq_lens = attn_metadata.q_cu_seq_lens; if (attn_metadata.is_prefill) { attention_params.key = key; @@ -127,6 +145,12 @@ std::tuple> AttentionImpl::forward( attention_params.block_table = attn_metadata.block_table; attention_params.kv_seq_lens = attn_metadata.kv_seq_lens; + // for flashinfer + attention_params.paged_kv_indptr = attn_metadata.paged_kv_indptr; + attention_params.paged_kv_indices = attn_metadata.paged_kv_indices; + attention_params.paged_kv_last_page_len = + attn_metadata.paged_kv_last_page_len; + xllm::kernel::batch_decode(attention_params); } diff --git a/xllm/core/layers/common/attention.h 
b/xllm/core/layers/common/attention.h index 7e210001..60fd92da 100644 --- a/xllm/core/layers/common/attention.h +++ b/xllm/core/layers/common/attention.h @@ -44,6 +44,13 @@ struct AttentionMetadata { std::string compute_dtype; bool is_prefill; bool is_chunked_prefill; + + // for flashinfer + torch::Tensor paged_kv_indptr; + torch::Tensor paged_kv_indices; + torch::Tensor paged_kv_last_page_len; + torch::Tensor q_cu_seq_lens; + torch::Tensor kv_cu_seq_lens; }; class AttentionImpl : public torch::nn::Module { diff --git a/xllm/core/layers/common/dense_mlp.h b/xllm/core/layers/common/dense_mlp.h index 2e517c19..a682e581 100644 --- a/xllm/core/layers/common/dense_mlp.h +++ b/xllm/core/layers/common/dense_mlp.h @@ -21,7 +21,7 @@ limitations under the License. #include "framework/parallel_state/parallel_args.h" #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" -#include "layers/linear.h" +#include "linear.h" namespace xllm { namespace layer { diff --git a/xllm/core/layers/common/fused_moe.h b/xllm/core/layers/common/fused_moe.h index 2e3154d0..35a9ddbf 100644 --- a/xllm/core/layers/common/fused_moe.h +++ b/xllm/core/layers/common/fused_moe.h @@ -24,7 +24,7 @@ limitations under the License. #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" #include "framework/state_dict/utils.h" -#include "layers/linear.h" +#include "linear.h" namespace xllm { namespace layer { diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h index 70e8af71..a1d17362 100644 --- a/xllm/core/layers/common/indexer.h +++ b/xllm/core/layers/common/indexer.h @@ -28,7 +28,7 @@ limitations under the License. #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" #include "framework/state_dict/utils.h" -#include "layers/linear.h" +#include "linear.h" #include "rotary_embedding.h" namespace xllm { diff --git a/xllm/core/layers/linear.h b/xllm/core/layers/common/linear.h similarity index 98% rename from xllm/core/layers/linear.h rename to xllm/core/layers/common/linear.h index 7870dbeb..a2b238ab 100644 --- a/xllm/core/layers/linear.h +++ b/xllm/core/layers/common/linear.h @@ -18,14 +18,11 @@ limitations under the License. #include #include -#if defined(USE_MLU) -#include "common/linear_impl.h" -#endif +#include "linear_impl.h" namespace xllm { namespace layer { -#if defined(USE_MLU) class ColumnParallelLinear : public torch::nn::ModuleHolder { public: @@ -123,7 +120,6 @@ class ReplicatedLinear : public torch::nn::ModuleHolder { quant_args, options)) {} }; -#endif } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/qwen3_attention.h b/xllm/core/layers/common/qwen3_attention.h index 9d5536ce..6b2bc2ba 100644 --- a/xllm/core/layers/common/qwen3_attention.h +++ b/xllm/core/layers/common/qwen3_attention.h @@ -23,8 +23,8 @@ limitations under the License. #include "framework/parallel_state/parallel_args.h" #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" -#include "layers/linear.h" #include "layers/rms_norm.h" +#include "linear.h" #include "rotary_embedding.h" namespace xllm { diff --git a/xllm/core/layers/common/rotary_embedding.cpp b/xllm/core/layers/common/rotary_embedding.cpp index 1280e29c..1dcd57d8 100644 --- a/xllm/core/layers/common/rotary_embedding.cpp +++ b/xllm/core/layers/common/rotary_embedding.cpp @@ -16,6 +16,7 @@ limitations under the License. 
#include "rotary_embedding.h" #include "kernels/ops_api.h" +#include "platform/device.h" namespace xllm { namespace layer { @@ -29,8 +30,7 @@ RotaryEmbeddingImpl::RotaryEmbeddingImpl(int rotary_dim, max_position_embeddings_(max_position_embeddings), rope_theta_(rope_theta), interleaved_(interleaved) { - auto dev_options = - torch::TensorOptions().device(torch::DeviceType::PrivateUse1); + auto dev_options = torch::TensorOptions().device(Device::type_torch()); auto inv_freq_t = torch::arange(/*start=*/0, /*end=*/rotary_dim_, diff --git a/xllm/core/layers/common/tests/tests_utils.cpp b/xllm/core/layers/common/tests/tests_utils.cpp index 86cc315a..e868acf5 100644 --- a/xllm/core/layers/common/tests/tests_utils.cpp +++ b/xllm/core/layers/common/tests/tests_utils.cpp @@ -15,6 +15,8 @@ limitations under the License. #include "tests_utils.h" +#include "core/platform/device.h" + namespace xllm { namespace layer { namespace test { @@ -118,7 +120,7 @@ QuantArgs CreateDefaultQuantArgs() { torch::TensorOptions CreateDefaultTensorOptions() { return torch::TensorOptions() .dtype(torch::kBFloat16) - .device(c10::DeviceType::PrivateUse1, 0) + .device(Device::type_torch(), 0) .requires_grad(false); } @@ -126,7 +128,7 @@ ParallelArgs CreateDefaultParallelArgs( std::unique_ptr& mock_process_group) { // Create mock ProcessGroup for MLU testing mock_process_group = std::make_unique( - torch::Device(c10::DeviceType::PrivateUse1, 0)); + torch::Device(Device::type_torch(), 0)); // Initialize ParallelArgs with mock ProcessGroup ParallelArgs parallel_args(0, 1, mock_process_group.get()); diff --git a/xllm/core/layers/common/tests/tests_utils.h b/xllm/core/layers/common/tests/tests_utils.h index 7918f494..e8598cbe 100644 --- a/xllm/core/layers/common/tests/tests_utils.h +++ b/xllm/core/layers/common/tests/tests_utils.h @@ -123,9 +123,11 @@ class MockBackend : public c10d::Backend { int getSize() const { return world_size_; } +#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7 void shutdown() override { // Mock implementation - do nothing } +#endif private: int rank_; diff --git a/xllm/core/platform/CMakeLists.txt b/xllm/core/platform/CMakeLists.txt index 97a53a22..3549c616 100644 --- a/xllm/core/platform/CMakeLists.txt +++ b/xllm/core/platform/CMakeLists.txt @@ -19,6 +19,7 @@ cc_library( $<$:cnrt> $<$:cndrv> $<$:cuda> + $<$:cudart> ) if(USE_NPU) diff --git a/xllm/core/platform/device.cpp b/xllm/core/platform/device.cpp index 6c3763c6..4687fb60 100644 --- a/xllm/core/platform/device.cpp +++ b/xllm/core/platform/device.cpp @@ -15,9 +15,12 @@ limitations under the License. 
#include "device.h" #if defined(USE_MLU) -#include +#include #include #include +#elif defined(USE_CUDA) +#include +#include #endif namespace xllm { @@ -27,16 +30,13 @@ Device::Device(torch::Device device) : device_(device) {} Device::operator torch::Device() const { return unwrap(); } void Device::set_device() const { - int ret = 0; #if defined(USE_NPU) - ret = c10_npu::SetDevice(index()); + c10_npu::set_device(index()); #elif defined(USE_MLU) torch_mlu::setDevice(index()); +#elif defined(USE_CUDA) + c10::cuda::set_device(index()); #endif - - if (ret != 0) { - LOG(ERROR) << "set device id: " << index() << " failed, ret:" << ret; - } } const torch::Device& Device::unwrap() const { return device_; } @@ -45,9 +45,8 @@ int32_t Device::index() const { return device_.index(); } // set device before init device context void Device::init_device_context() const { - std::string device_name = type() + ":" + std::to_string(index()); #if defined(USE_NPU) - torch_npu::init_npu(device_name); + torch_npu::init_npu(index()); #endif } @@ -56,14 +55,26 @@ int Device::device_count() { return c10_npu::device_count(); #elif defined(USE_MLU) return torch_mlu::device_count(); +#elif defined(USE_CUDA) + return c10::cuda::device_count(); #endif } -const std::string Device::type() { +std::string Device::type_str() { #if defined(USE_NPU) return "npu"; #elif defined(USE_MLU) return "mlu"; +#elif defined(USE_CUDA) + return "cuda"; +#endif +} + +torch::DeviceType Device::type_torch() { +#if defined(USE_NPU) || defined(USE_MLU) + return torch::kPrivateUse1; +#elif defined(USE_CUDA) + return torch::kCUDA; #endif } @@ -75,8 +86,9 @@ Device::DeviceMem Device::get_device_mem() const { #if defined(USE_NPU) aclrtGetMemInfo(ACL_HBM_MEM, &free_memory, &total_memory); #elif defined(USE_MLU) - std::tie(free_memory, total_memory) = - torch_mlu::MLUCachingAllocator::MemGetInfo(index()); + cnrtMemGetInfo(&free_memory, &total_memory); +#elif defined(USE_CUDA) + cudaMemGetInfo(&free_memory, &total_memory); #endif device_mem.total_memory = static_cast(total_memory); device_mem.free_memory = static_cast(free_memory); @@ -89,11 +101,13 @@ int64_t Device::free_memory() { return get_device_mem().free_memory; } int Device::synchronize_default_stream() { #if defined(USE_NPU) - return aclrtSynchronizeStream(c10_npu::getCurrentNPUStream(index()).stream()); + c10_npu::getCurrentNPUStream(index()).synchronize(); #elif defined(USE_MLU) torch_mlu::getCurrentMLUStream(index()).synchronize(); - return 0; +#elif defined(USE_CUDA) + c10::cuda::getCurrentCUDAStream().synchronize(); #endif + return 0; } std::unique_ptr Device::get_stream_from_pool() { diff --git a/xllm/core/platform/device.h b/xllm/core/platform/device.h index 65c5b5a6..0a08f649 100644 --- a/xllm/core/platform/device.h +++ b/xllm/core/platform/device.h @@ -15,6 +15,7 @@ limitations under the License. 
#pragma once +#include #include #include @@ -38,7 +39,8 @@ class Device { void init_device_context() const; static int device_count(); - static const std::string type(); + static std::string type_str(); + static torch::DeviceType type_torch(); int64_t total_memory(); int64_t free_memory(); diff --git a/xllm/core/platform/stream.cpp b/xllm/core/platform/stream.cpp index 5cb15b48..6e69276d 100644 --- a/xllm/core/platform/stream.cpp +++ b/xllm/core/platform/stream.cpp @@ -21,15 +21,13 @@ namespace xllm { Stream::Stream() : stream_(c10_npu::getNPUStreamFromPool()) {} #elif defined(USE_MLU) Stream::Stream() : stream_(torch_mlu::getStreamFromPool()) {} +#elif defined(USE_CUDA) +Stream::Stream() : stream_(c10::cuda::getStreamFromPool()) {} #endif int Stream::synchronize() const { -#if defined(USE_NPU) - return aclrtSynchronizeStream(stream_.stream()); -#elif defined(USE_MLU) stream_.unwrap().synchronize(); return 0; -#endif } c10::StreamGuard Stream::set_stream_guard() const { diff --git a/xllm/core/platform/stream.h b/xllm/core/platform/stream.h index 7cb65913..843105cb 100644 --- a/xllm/core/platform/stream.h +++ b/xllm/core/platform/stream.h @@ -21,13 +21,17 @@ limitations under the License. #endif // clang-format on +#include +#include + #include #if defined(USE_NPU) #include #include #elif defined(USE_MLU) -#include #include +#elif defined(USE_CUDA) +#include #endif namespace xllm { @@ -50,6 +54,8 @@ class Stream { c10_npu::NPUStream stream_; #elif defined(USE_MLU) torch_mlu::MLUStream stream_; +#elif defined(USE_CUDA) + c10::cuda::CUDAStream stream_; #endif }; diff --git a/xllm/core/platform/vmm_api.cpp b/xllm/core/platform/vmm_api.cpp index fc390611..09803434 100644 --- a/xllm/core/platform/vmm_api.cpp +++ b/xllm/core/platform/vmm_api.cpp @@ -98,7 +98,7 @@ void create_vir_ptr(VirPtr& vir_ptr, size_t aligned_size) { #elif defined(USE_MLU) ret = cnMemAddressReserve(&vir_ptr, aligned_size, 0, 0, 0); #elif defined(USE_CUDA) - ret = cuMemAddressReserve(&vir_ptr, aligned_size, 0, nullptr, 0); + ret = cuMemAddressReserve(&vir_ptr, aligned_size, 0, 0, 0); #endif CHECK_EQ(ret, 0) << "Failed to create virtual memory handle"; } diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index dd4a3d8f..8f8b63ad 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -180,6 +180,10 @@ struct RawForwardInput { std::vector kv_cache_start_offsets; //[n_seq] // beam search kernel input std::vector acc_logprob_vec; + // for flashinfer + std::vector paged_kv_indptr; //[n_seq + 1] + std::vector paged_kv_indices; //[num_used_pages] + std::vector paged_kv_last_page_len; //[n_seq] }; struct RawSampleOutput { diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index b2eddb5c..363df019 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -304,7 +304,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { kv_cache_shape.emplace_back(std::vector{ kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()}); } else { -#if defined(USE_NPU) +#if defined(USE_NPU) || defined(USE_CUDA) kv_cache_shape.emplace_back(std::vector{ kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_}); kv_cache_shape.emplace_back(std::vector{ diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 7ce016a7..62d94053 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -26,6 +26,7 @@ 
limitations under the License. #include #include "common/device_monitor.h" +#include "common/flashinfer_workspace.h" #include "common/metrics.h" #include "common/types.h" #include "core/common/global_flags.h" @@ -43,6 +44,8 @@ LLMWorkerImpl::LLMWorkerImpl(const ParallelArgs& parallel_args, const runtime::Options& options) : WorkerImpl(parallel_args, device, options) { device_.set_device(); + // initialize flashinfer workspace + FlashinferWorkspace::get_instance().initialize(device_); } bool LLMWorkerImpl::init_model(ModelContext& context) { diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index 0ade6f56..d05faaa7 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -41,10 +41,6 @@ limitations under the License. #include "util/scope_guard.h" #include "util/timer.h" -#if defined(USE_NPU) -#include -#endif - namespace xllm { Master::Master(const Options& options, EngineType type) : options_(options) { diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 428c0c3e..16bf6fff 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -64,6 +64,16 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector(pb_forward_input->q_seq_lens().begin(), pb_forward_input->q_seq_lens().end()); // aprint(q_seq_lens, "q_seq_lens", global_rank_); + // for flashinfer + std::vector paged_kv_indptr = + std::vector(pb_forward_input->paged_kv_indptr().begin(), + pb_forward_input->paged_kv_indptr().end()); + std::vector paged_kv_indices = + std::vector(pb_forward_input->paged_kv_indices().begin(), + pb_forward_input->paged_kv_indices().end()); + std::vector paged_kv_last_page_len = + std::vector(pb_forward_input->paged_kv_last_page_len().begin(), + pb_forward_input->paged_kv_last_page_len().end()); std::vector> block_tables_vec; for (size_t i = 0; i < pb_forward_input->block_tables_vec().size(); ++i) { block_tables_vec.emplace_back(std::vector( @@ -213,6 +223,12 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.kv_seq_lens_vec = std::move(seq_lens); input_params.q_seq_lens_vec = std::move(q_seq_lens); + input_params.paged_kv_indptr = torch::tensor(paged_kv_indptr, tensor_options); + input_params.paged_kv_indices = + torch::tensor(paged_kv_indices, tensor_options); + input_params.paged_kv_last_page_len = + torch::tensor(paged_kv_last_page_len, tensor_options); + input_params.new_cache_slots = torch::tensor(new_token_slot_ids, tensor_options); input_params.decode_seq_range = decode_seq_range; @@ -396,6 +412,13 @@ void forward_input_to_proto(const RawForwardInput& inputs, ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_seq_lens(), inputs.seq_lens); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_q_seq_lens(), inputs.q_seq_lens); + // for flashinfer + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_paged_kv_indptr(), + inputs.paged_kv_indptr); + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_paged_kv_indices(), + inputs.paged_kv_indices); + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_paged_kv_last_page_len(), + inputs.paged_kv_last_page_len); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_new_token_slot_ids(), inputs.new_token_slot_ids); pb_forward_input->mutable_block_tables_vec()->Reserve( diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 3fad5fc7..4afe14ee 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -24,6 +24,8 @@ limitations under the License. 
#include "kernels/npu/xllm_ops/replace_token.h" #elif defined(USE_MLU) #include +#elif defined(USE_CUDA) +#include #endif #include @@ -92,7 +94,7 @@ bool WorkerImpl::allocate_kv_cache( value_cache = at_npu::native::npu_format_cast( torch::empty(kv_cache_shape[1], torch::dtype(dtype_).device(device_)), 2); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) key_cache = torch::empty(kv_cache_shape[0], torch::dtype(dtype_).device(device_)); value_cache = @@ -300,6 +302,8 @@ std::tuple WorkerImpl::estimate_kv_cache_capacity() { device_id, &torch_cache, &torch_largest_block); #elif defined(USE_MLU) torch_mlu::MLUCachingAllocator::emptyCache(); +#elif defined(USE_CUDA) + c10::cuda::CUDACachingAllocator::emptyCache(); #endif const auto available_memory = device_.free_memory(); const auto total_memory = device_.total_memory(); @@ -351,16 +355,16 @@ void WorkerImpl::update_last_step_output( ForwardInput WorkerImpl::update_input_by_last_step_output( ForwardInput& inputs) { -#if defined(USE_A3) || defined(USE_MLU) +#if defined(USE_A2) + xllm_ops::replace_token(inputs.token_ids, + last_step_output_.sample_output.next_tokens); +#else auto& flatten_tokens = inputs.token_ids; auto neg_mask = (flatten_tokens < 0); auto clamped_neg_indices = torch::clamp(-flatten_tokens, 0); auto replacement = last_step_output_.sample_output.next_tokens.index( {clamped_neg_indices - 1}); inputs.token_ids = torch::where(neg_mask, replacement, flatten_tokens); -#else - xllm_ops::replace_token(inputs.token_ids, - last_step_output_.sample_output.next_tokens); #endif return inputs; } diff --git a/xllm/core/util/device_name_utils.cpp b/xllm/core/util/device_name_utils.cpp index 7d0919a3..53654f7b 100644 --- a/xllm/core/util/device_name_utils.cpp +++ b/xllm/core/util/device_name_utils.cpp @@ -36,7 +36,7 @@ std::vector DeviceNameUtils::parse_devices( } devices.reserve(num_devices); for (int i = 0; i < num_devices; ++i) { - std::string device_name = Device::type() + ":" + std::to_string(i); + std::string device_name = Device::type_str() + ":" + std::to_string(i); devices.emplace_back(torch::Device(device_name)); } return devices; @@ -49,14 +49,14 @@ std::vector DeviceNameUtils::parse_devices( for (const auto& device_str : device_strs) { std::vector parts = absl::StrSplit(device_str, ':'); CHECK(parts.size() == 2) << "Invalid device string format: " << device_str; - CHECK(parts[0] == Device::type()) + CHECK(parts[0] == Device::type_str()) << "Unsupported device type: " << parts[0]; int device_index; CHECK(absl::SimpleAtoi(parts[1], &device_index)) << "Invalid device index: " << parts[1]; - devices.emplace_back(c10::DeviceType::PrivateUse1, device_index); + devices.emplace_back(Device::type_torch(), device_index); device_types.insert(devices.back().type()); } CHECK(!devices.empty()) << "No devices specified."; diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index 156f8169..bd8cd4d2 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -125,11 +125,9 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { } virtual void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); -#if defined(USE_NPU) block_copy_->merge_loaded_weights(); -#endif } -#elif defined(USE_MLU) +#else virtual torch::Tensor forward(torch::Tensor& x, torch::Tensor& positions, const layer::AttentionMetadata& attn_metadata, @@ -166,7 +164,7 @@ class LlmModelImplBase : public torch::nn::Module { torch::Tensor get_input_embeddings(torch::Tensor input_ids) { #if 
defined(USE_NPU) return embed_tokens_[0](input_ids, 0); -#elif defined(USE_MLU) +#else return embed_tokens_[0](input_ids); #endif } @@ -204,7 +202,7 @@ class LlmModelImplBase : public torch::nn::Module { } else { #if defined(USE_NPU) h = embed_tokens_[i](tokens[i], 0); -#elif defined(USE_MLU) +#else h = embed_tokens_[i](tokens[i]); #endif } @@ -307,7 +305,7 @@ class LlmModelImplBase : public torch::nn::Module { } auto cancated_h = torch::cat(hs, 0); return norm_(cancated_h, 0); -#elif defined(USE_MLU) +#else bool is_prefill = input_params[0].q_max_seq_len > 1; auto attn_metadata = layer::AttentionMetadata::build(input_params[0], is_prefill); @@ -407,7 +405,7 @@ class LlmForCausalLMImplBase : public torch::nn::Module { #if defined(USE_NPU) lm_head_ = register_module("lm_head", layer::LmHead(context)); -#elif defined(USE_MLU) +#else // lm_head_ is default to no quantization lm_head_ = register_module("lm_head", @@ -446,7 +444,7 @@ class LlmForCausalLMImplBase : public torch::nn::Module { // test #if defined(USE_NPU) return lm_head_(hidden_states, seleted_idxes, 0); -#elif defined(USE_MLU) +#else if (seleted_idxes.defined()) { h = h.index_select(/*dim=*/0, seleted_idxes); } diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h index 4e958790..216e760c 100755 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -56,7 +56,7 @@ class QWen3ModelImpl : public LlmModelImplBase { attn_mask_ = layer::AttentionMask(options.device(), options.dtype().toScalarType(), /*mask_value=*/mask_value); -#elif defined(USE_MLU) +#else norm_ = register_module( "norm", layer::RmsNorm( @@ -119,7 +119,7 @@ class QWen3ModelImpl : public LlmModelImplBase { } else { #if defined(USE_NPU) h = embed_tokens_[i](tokens[i], 0); -#elif defined(USE_MLU) +#else h = embed_tokens_[i](tokens[i]); #endif } @@ -225,7 +225,7 @@ class QWen3ModelImpl : public LlmModelImplBase { } auto cancated_h = torch::cat(hs, 0); return norm_(cancated_h, 0); -#elif defined(USE_MLU) +#else bool is_prefill = input_params[0].q_max_seq_len > 1; auto attn_metadata = layer::AttentionMetadata::build(input_params[0], is_prefill); diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index d540f94d..e1996b86 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -59,7 +59,7 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { event, event_flag); } -#elif defined(USE_MLU) +#else torch::Tensor forward(torch::Tensor& x, torch::Tensor& positions, const layer::AttentionMetadata& attn_metadata, @@ -163,7 +163,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module { /*mask_value=*/mask_value); norm_ = register_module("norm", layer::RmsNorm(context)); mapping_data_ = parallel_args.mapping_data(); -#elif defined(USE_MLU) +#else norm_ = register_module( "norm", layer::RmsNorm( @@ -290,7 +290,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module { } } return norm_(h, 0); -#elif defined(USE_MLU) +#else ModelInputParams modified_input_params = input_params; layer::update_dummy_run_input(dp_rank_, positions, modified_input_params); bool is_prefill = modified_input_params.q_max_seq_len > 1; @@ -384,7 +384,7 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { model_ = register_module("model", Qwen3MoeModel(context)); #if defined(USE_NPU) lm_head_ = register_module("lm_head", layer::LmHead(context)); -#elif defined(USE_MLU) +#else // lm_head_ is default to no quantization lm_head_ = register_module("lm_head", @@ -415,7 +415,7 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { const 
torch::Tensor& seleted_idxes) { #if defined(USE_NPU) return lm_head_(hidden_states, seleted_idxes, 0); -#elif defined(USE_MLU) +#else // select tokens if provided auto h = hidden_states; if (seleted_idxes.defined()) { diff --git a/xllm/processors/CMakeLists.txt b/xllm/processors/CMakeLists.txt index fd3e6070..27365efe 100755 --- a/xllm/processors/CMakeLists.txt +++ b/xllm/processors/CMakeLists.txt @@ -11,17 +11,9 @@ set(BASE_DEPS :chat_template glog::glog torch + torch_python ) -if(USE_NPU) - # Check if NPU is being used - include_directories($ENV{XLLM_KERNELS_PATH}/include/xllm_kernels/core/include) - include_directories($ENV{XLLM_KERNELS_PATH}/include/xllm_kernels) - - # Modify dependencies for NPU - list(APPEND BASE_DEPS torch_npu) - list(APPEND BASE_DEPS :npu_layers) -endif() # Define the library cc_library( @@ -42,4 +34,4 @@ cc_library( pywarpper_image_processor.cpp DEPS ${BASE_DEPS} -) \ No newline at end of file +)
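Note on the new flashinfer paged-KV fields (paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len) threaded through RawForwardInput, the proto conversion, and AttentionMetadata above: they describe each sequence's KV cache pages in CSR form, which is what cuda::batch_decode consumes. The sketch below is illustrative only and not part of this patch; the helper name build_paged_kv_metadata and the inputs block_tables_vec, kv_seq_lens, and block_size are assumptions chosen to mirror the existing per-sequence block tables, and the code assumes every sequence holds at least one KV token.

// Illustrative sketch (not part of the diff): deriving the flashinfer paged-KV
// metadata from per-sequence block tables and KV lengths. All names here are
// hypothetical; they only mirror the shapes documented in forward_params.h.
#include <cstdint>
#include <vector>

struct PagedKVMetadata {
  std::vector<int32_t> paged_kv_indptr;         // [n_seq + 1], CSR offsets into indices
  std::vector<int32_t> paged_kv_indices;        // [num_used_pages], flattened page ids
  std::vector<int32_t> paged_kv_last_page_len;  // [n_seq], valid tokens in the last page
};

PagedKVMetadata build_paged_kv_metadata(
    const std::vector<std::vector<int32_t>>& block_tables_vec,
    const std::vector<int32_t>& kv_seq_lens,
    int32_t block_size) {
  PagedKVMetadata meta;
  meta.paged_kv_indptr.push_back(0);
  for (size_t i = 0; i < block_tables_vec.size(); ++i) {
    const int32_t kv_len = kv_seq_lens[i];  // assumed >= 1 for a scheduled sequence
    // Number of pages that actually hold tokens for this sequence.
    const int32_t n_pages = (kv_len + block_size - 1) / block_size;
    for (int32_t p = 0; p < n_pages; ++p) {
      meta.paged_kv_indices.push_back(block_tables_vec[i][p]);
    }
    meta.paged_kv_indptr.push_back(meta.paged_kv_indptr.back() + n_pages);
    // Tokens in the final page; a full page when kv_len is a multiple of block_size.
    const int32_t last_len = kv_len % block_size;
    meta.paged_kv_last_page_len.push_back(last_len == 0 ? block_size : last_len);
  }
  return meta;
}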
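The patch also includes common/flashinfer_workspace.h from attention.cpp and llm_worker_impl.cpp, but that header is not shown in this excerpt. The following is only a minimal sketch of a singleton consistent with the calls used above (get_instance(), initialize(), and the three get_*_workspace_buffer() accessors); the buffer sizes and member layout are assumptions, not the repository's actual implementation.

// Sketch only, assuming an interface matching the call sites in this diff.
#include <torch/torch.h>

namespace xllm {

class FlashinferWorkspace {
 public:
  static FlashinferWorkspace& get_instance() {
    static FlashinferWorkspace instance;
    return instance;
  }

  // Allocate the workspace buffers once per worker on its device.
  void initialize(const torch::Device& device) {
    // 128 MB float workspace and 8 MB int workspaces are illustrative sizes.
    auto options = torch::dtype(torch::kUInt8).device(device);
    float_workspace_buffer_ = torch::empty({128 * 1024 * 1024}, options);
    int_workspace_buffer_ = torch::empty({8 * 1024 * 1024}, options);
    page_locked_int_workspace_buffer_ = torch::empty(
        {8 * 1024 * 1024},
        torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true));
  }

  torch::Tensor get_float_workspace_buffer() { return float_workspace_buffer_; }
  torch::Tensor get_int_workspace_buffer() { return int_workspace_buffer_; }
  torch::Tensor get_page_locked_int_workspace_buffer() {
    return page_locked_int_workspace_buffer_;
  }

 private:
  FlashinferWorkspace() = default;
  torch::Tensor float_workspace_buffer_;
  torch::Tensor int_workspace_buffer_;
  torch::Tensor page_locked_int_workspace_buffer_;
};

}  // namespace xllm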