Skip to content

Commit f1a0802

Browse files
committed
feat: enable torch_npu graph mode for Qwen-3 dense with single and multi-card TP support.
1 parent 90dfadb commit f1a0802

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+2246
-3349
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ else()
298298
endif()
299299

300300
if(USE_NPU)
301+
add_definitions(-DUSE_NPU_TORCH)
301302
add_definitions(-DUSE_NPU)
302303
add_definitions(-DBUILD_LIBTORCH)
303304
add_definitions(-DTORCH_SETCUSTOMHANDLER=ON)
@@ -309,6 +310,7 @@ if(USE_NPU)
309310
$ENV{PYTORCH_INSTALL_PATH}/include
310311
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
311312
$ENV{PYTORCH_NPU_INSTALL_PATH}/include
313+
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed
312314
$ENV{NPU_HOME_PATH}/include
313315
$ENV{ATB_HOME_PATH}/include
314316
$ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from setuptools.command.bdist_wheel import bdist_wheel
1818
from setuptools.command.build_ext import build_ext
1919

20-
BUILD_TEST_FILE = True
20+
BUILD_TEST_FILE = False
2121
BUILD_EXPORT = True
2222

2323
# get cpu architecture

xllm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb ZLIB::ZLIB p
3434
add_dependencies(xllm brpc-static)
3535

3636
if(USE_NPU)
37-
set(COMMON_LIBS Python::Python ascendcl hccl c_sec nnopbase ms_tools_ext)
37+
set(COMMON_LIBS Python::Python ascendcl hccl c_sec nnopbase ms_tools_ext torch_npu torch_python)
3838
elseif(USE_MLU)
3939
set(COMMON_LIBS Python::Python)
4040
endif()

xllm/core/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ cc_library(
2828
absl::random_random
2929
absl::strings
3030
torch
31+
$<$<BOOL:${USE_NPU}>:torch_python>
3132
$<$<BOOL:${USE_NPU}>:torch_npu>
3233
$<$<BOOL:${USE_MSPTI}>:mspti>
3334
$<$<BOOL:${USE_NPU}>:ms_tools_ext>

xllm/core/common/global_flags.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,5 @@ DEFINE_string(reasoning_parser,
389389

390390
// --- qwen3 reranker config ---
391391
DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker.");
392+
393+
DEFINE_bool(enable_native_npu, true, "Whether to enable native NPU support.");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,5 @@ DECLARE_bool(enable_qwen3_reranker);
202202
DECLARE_string(reasoning_parser);
203203

204204
DECLARE_bool(enable_shm);
205+
206+
DECLARE_bool(enable_native_npu);

xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cc_binary(
1212
:models
1313
:model
1414
:distributed_runtime
15+
:parallel_state
1516
absl::strings
1617
xllm_kernels
1718
ascendcl

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ void WorkerServer::create_server(
100100
const ParallelArgs* parallel_args = comm.parallel_args();
101101
#if defined(USE_MLU) || defined(USE_CUDA)
102102
comm.create_process_groups(master_node_addr, device);
103+
#elif defined(USE_NPU)
104+
// TODO: Refactor to use model_type or other appropriate enumeration for
105+
// condition checking
106+
if (FLAGS_enable_native_npu) {
107+
comm.create_process_groups(master_node_addr, device);
108+
}
103109
#endif
104110

105111
WorkerType worker_type =

xllm/core/framework/model/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ set(BASE_DEPS
1717
if(USE_NPU)
1818
list(APPEND BASE_DEPS :npu_layers)
1919
list(APPEND BASE_DEPS :platform_npu)
20-
else()
21-
list(APPEND BASE_DEPS :common_layers)
2220
endif()
2321

22+
list(APPEND BASE_DEPS :common_layers)
23+
2424

2525
# Define the library
2626
cc_library(

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ limitations under the License.
1818
#include "mapping_npu.h"
1919

2020
#if defined(USE_NPU)
21+
#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>
22+
23+
#include "npu_process_group.h"
2124
#include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h"
2225
#include "xllm_kernels/core/include/atb_speed/utils/singleton.h"
2326
#include "xllm_kernels/models/base/param/mapping.h"
@@ -30,23 +33,6 @@ limitations under the License.
3033
#include "parallel_args.h"
3134
#include "util/net.h"
3235

33-
namespace {
34-
#if defined(USE_NPU)
35-
std::unique_ptr<xllm::ProcessGroup> create_process_group(
36-
int rank,
37-
int world_size,
38-
int rank_size,
39-
int port,
40-
bool trans,
41-
const std::string& host,
42-
const std::string& group_name,
43-
const torch::Device& device) {
44-
LOG(FATAL) << "Unsupported device type";
45-
return nullptr;
46-
}
47-
#endif
48-
} // namespace
49-
5036
namespace xllm {
5137

5238
CollectiveCommunicator::CollectiveCommunicator(int global_rank,

0 commit comments

Comments
 (0)