Skip to content

Commit eda1805

Browse files
a120092009Gossity
andauthored
feat: support qwen3-32b and qwen3-30b-a3b on mlu device. (#227)
Co-authored-by: guoxueting <[email protected]>
1 parent aa54363 commit eda1805

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+3319
-185
lines changed

CMakeLists.txt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,24 @@ endif()
332332

333333
if(USE_MLU)
334334
add_definitions(-DUSE_MLU)
335-
# TODO(mlu): set mlu environment variables
335+
set(CMAKE_VERBOSE_MAKEFILE ON)
336+
include_directories(
337+
$ENV{PYTHON_INCLUDE_PATH}
338+
$ENV{PYTORCH_INSTALL_PATH}/include
339+
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
340+
$ENV{PYTORCH_MLU_INSTALL_PATH}
341+
$ENV{PYTORCH_MLU_INSTALL_PATH}/../
342+
$ENV{PYTORCH_MLU_INSTALL_PATH}/csrc
343+
$ENV{NEUWARE_HOME}/include
344+
)
345+
346+
link_directories(
347+
$ENV{PYTHON_LIB_PATH}
348+
$ENV{PYTORCH_INSTALL_PATH}/lib
349+
$ENV{PYTORCH_MLU_INSTALL_PATH}/csrc/lib
350+
$ENV{PYTORCH_MLU_INSTALL_PATH}
351+
$ENV{NEUWARE_HOME}/lib64
352+
)
336353
endif()
337354

338355
# check if USE_CXX11_ABI is set correctly

setup.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ def get_torch_root_path():
106106
except ImportError:
107107
return None
108108

109+
def get_torch_mlu_root_path():
110+
try:
111+
import torch_mlu
112+
import os
113+
return os.path.dirname(os.path.abspath(torch_mlu.__file__))
114+
except ImportError:
115+
return None
116+
109117

110118
def set_npu_envs():
111119
PYTORCH_NPU_INSTALL_PATH = os.getenv("PYTORCH_NPU_INSTALL_PATH")
@@ -196,11 +204,13 @@ def set_npu_envs():
196204
os.environ["LCCL_DETERMINISTIC"] = "0"
197205
os.environ["LCCL_PARALLEL"] = "0"
198206

199-
# TODO(mlu): set mlu environment variables
207+
200208
def set_mlu_envs():
201209
os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
202210
os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
203211
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
212+
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
213+
os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path()
204214

205215
class CMakeExtension(Extension):
206216
def __init__(self, name: str, path: str, sourcedir: str = "") -> None:

xllm/core/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ cc_library(
3434
Boost::serialization
3535
cpprest
3636
etcd-cpp-api
37+
$<$<BOOL:${USE_MLU}>:torch_mlu>
3738
)
3839

3940
cc_library(

xllm/core/distributed_runtime/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cc_library(
3131
proto::xllm_proto
3232
absl::flat_hash_set
3333
:parallel_state
34+
:collective_service
3435
)
3536

3637
cc_library(

xllm/core/distributed_runtime/collective_service.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ void CollectiveService::Sync(::google::protobuf::RpcController* controller,
4444
const proto::AddressInfo* request,
4545
proto::CommUniqueIdList* response,
4646
::google::protobuf::Closure* done) {
47-
#if defined(USE_NPU)
4847
brpc::ClosureGuard done_guard(done);
4948

5049
std::string address = request->address();
@@ -53,10 +52,9 @@ void CollectiveService::Sync(::google::protobuf::RpcController* controller,
5352
std::lock_guard<std::mutex> lock(mutex_);
5453
addrs_map_[global_rank] = address;
5554
}
56-
55+
#if defined(USE_NPU)
5756
to_proto_list(root_infos_, response);
5857
#endif
59-
return;
6058
}
6159

6260
std::unordered_map<int32_t, std::string> CollectiveService::wait() {

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ void WorkerServer::create_server(const runtime::Options& options,
9191

9292
CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size);
9393
const ParallelArgs* parallel_args = comm.parallel_args();
94+
#if defined(USE_MLU)
95+
comm.create_process_groups_cncl(master_node_addr, device);
96+
#endif
9497

9598
WorkerType worker_type =
9699
(options.task_type() == "generate") ? WorkerType::LLM : WorkerType::ELM;

xllm/core/framework/CMakeLists.txt

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ add_subdirectory(xtensor)
2222
add_subdirectory(dit_cache)
2323

2424
cc_library(
25-
NAME
25+
NAME
2626
parallel_state
2727
HDRS
2828
mapping_npu.h
@@ -33,30 +33,33 @@ cc_library(
3333
DEPS
3434
:common
3535
torch
36-
hccl
36+
$<$<BOOL:${USE_MLU}>:torch_mlu>
37+
$<$<BOOL:${USE_NPU}>:hccl>
3738
glog::glog
3839
)
3940

4041

41-
cc_test(
42-
NAME
43-
mapping_npu_test
44-
SRCS
45-
mapping_npu_test.cpp
46-
DEPS
47-
parallel_state
48-
absl::synchronization
49-
absl::time
50-
GTest::gtest_main
51-
xllm_kernels
52-
ascendcl
53-
atb
54-
c_sec
55-
spdlog::spdlog
56-
)
42+
if(USE_NPU)
43+
cc_test(
44+
NAME
45+
mapping_npu_test
46+
SRCS
47+
mapping_npu_test.cpp
48+
DEPS
49+
parallel_state
50+
absl::synchronization
51+
absl::time
52+
GTest::gtest_main
53+
xllm_kernels
54+
ascendcl
55+
atb
56+
c_sec
57+
spdlog::spdlog
58+
)
59+
endif()
5760

5861
cc_library(
59-
NAME
62+
NAME
6063
model_loader
6164
HDRS
6265
hf_model_loader.h
@@ -78,7 +81,7 @@ cc_library(
7881
)
7982

8083
cc_library(
81-
NAME
84+
NAME
8285
model_context
8386
HDRS
8487
model_context.h

xllm/core/framework/block/CMakeLists.txt

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,26 @@ cc_library(
2828
torch
2929
)
3030

31-
set(TEST_SRCS
32-
block_manager_test.cpp
33-
)
31+
if(USE_NPU)
32+
set(TEST_SRCS
33+
block_manager_test.cpp
34+
)
3435

35-
cc_test(
36-
NAME
37-
block_test
38-
SRCS
39-
${TEST_SRCS}
40-
DEPS
41-
:block
42-
:flags
43-
:kv_cache
44-
:prefix_cache
45-
absl::random_random
46-
Boost::serialization
47-
GTest::gtest_main
48-
)
36+
cc_test(
37+
NAME
38+
block_test
39+
SRCS
40+
${TEST_SRCS}
41+
DEPS
42+
:block
43+
:flags
44+
:kv_cache
45+
:prefix_cache
46+
absl::random_random
47+
Boost::serialization
48+
GTest::gtest_main
49+
)
4950

50-
target_link_libraries(block_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto ascendcl Folly::folly)
51-
add_dependencies(block_test brpc-static)
51+
target_link_libraries(block_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto ascendcl Folly::folly)
52+
add_dependencies(block_test brpc-static)
53+
endif()

xllm/core/framework/dit_model_context.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ const QuantArgs& DiTModelContext::get_quant_args(
7272
}
7373
}
7474

75+
#if defined(USE_NPU)
7576
ModelContext DiTModelContext::get_model_context(
7677
const std::string& component) const {
7778
return ModelContext(parallel_args_,
@@ -80,5 +81,6 @@ ModelContext DiTModelContext::get_model_context(
8081
tensor_options_,
8182
context_);
8283
}
84+
#endif
8385

8486
} // namespace xllm

xllm/core/framework/dit_model_context.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ class DiTModelContext {
4242

4343
const QuantArgs& get_quant_args(const std::string& component) const;
4444

45+
#if defined(USE_NPU)
4546
ModelContext get_model_context(const std::string& component) const;
47+
#endif
4648

4749
const ParallelArgs& get_parallel_args() const { return parallel_args_; }
4850

@@ -52,7 +54,9 @@ class DiTModelContext {
5254

5355
const std::string& model_type() const { return model_type_; }
5456

57+
#if defined(USE_NPU)
5558
const atb::Context* get_atb_context() const { return context_; }
59+
#endif
5660

5761
private:
5862
std::unordered_map<std::string, ModelArgs> model_args_;

0 commit comments

Comments
 (0)