Commit f8807a8

a120092009 and Gossity authored
feat: support dp and ep and fix rope_forward on mlu device. (#285)
Co-authored-by: guoxueting <[email protected]>
1 parent 3a84d60 · commit f8807a8


44 files changed: +677 −298 lines
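In brief: dp, tp, and ep here denote data-, tensor-, and expert-parallel process groups. The commit adds moe_tp/moe_ep groups and a `trans` flag to process-group creation, and folds the per-device (MLU/CUDA) process-group implementations into their headers. Only the changed parallel-state files are shown below; the remaining files, including the mlu rope_forward fix, are among the 44 touched.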

xllm/core/framework/parallel_state/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
@@ -16,9 +16,8 @@ cc_library(
   SRCS
     mapping_npu.cpp
     parallel_state.cpp
+    process_group.cpp
     $<$<BOOL:${USE_NPU}>:npu_process_group.cpp>
-    $<$<BOOL:${USE_MLU}>:mlu_process_group.cpp>
-    $<$<BOOL:${USE_CUDA}>:cuda_process_group.cpp>
     collective_communicator.cpp
   DEPS
     :common
@@ -45,4 +44,4 @@ if(USE_NPU)
     c_sec
     spdlog::spdlog
   )
-endif()
+endif()
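A note on this build change: the MLU and CUDA process-group implementations move out of their own translation units and into their headers (see cuda_process_group.h below; both mlu_process_group.cpp and cuda_process_group.cpp are deleted), with the shared logic collected in the new process_group.cpp. The `$<$<BOOL:...>:...>` generator expressions for MLU and CUDA therefore disappear from SRCS, while the NPU entry stays.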

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 54 additions & 16 deletions
@@ -22,8 +22,6 @@ limitations under the License.
 #include "xllm_kernels/core/include/atb_speed/utils/singleton.h"
 #include "xllm_kernels/models/base/param/mapping.h"
 #elif defined(USE_MLU)
-#include <torch_mlu/csrc/framework/distributed/process_group_cncl.hpp>
-
 #include "mlu_process_group.h"
 #elif defined(USE_CUDA)
 #include "cuda_process_group.h"
@@ -33,25 +31,20 @@ limitations under the License.
 #include "util/net.h"
 
 namespace {
+#if defined(USE_NPU)
 std::unique_ptr<xllm::ProcessGroup> create_process_group(
     int rank,
     int world_size,
     int rank_size,
     int port,
+    bool trans,
     const std::string& host,
     const std::string& group_name,
     const torch::Device& device) {
-#if defined(USE_MLU)
-  return std::make_unique<xllm::ProcessGroupCncl>(
-      rank, world_size, rank_size, port, host, group_name, device);
-#elif defined(USE_CUDA)
-  return std::make_unique<xllm::ProcessGroupNccl>(
-      rank, world_size, rank_size, port, host, group_name, device);
-#else
   LOG(FATAL) << "Unsupported device type";
   return nullptr;
-#endif
 }
+#endif
 }  // namespace
 
 namespace xllm {
@@ -130,24 +123,69 @@ void CollectiveCommunicator::create_process_groups(
   int global_rank = parallel_args_->rank();
   int world_size = parallel_args_->world_size();
   int dp_size = parallel_args_->dp_size();
-
-  process_group_ = create_process_group(
-      global_rank, world_size, world_size, ++port, host, "world_group", device);
+  int ep_size = parallel_args_->ep_size();
+  process_group_ = create_process_group(global_rank,
+                                        world_size,
+                                        world_size,
+                                        ++port,
+                                        false,
+                                        host,
+                                        "world_group",
+                                        device);
+  parallel_args_->process_group_ = process_group_.get();
 
   int tp_size = world_size / dp_size;
   CHECK_EQ(tp_size * dp_size, world_size);
   int port_offset = global_rank / tp_size + 1;
-
   tp_group_ = create_process_group(global_rank,
                                    world_size,
                                    tp_size,
                                    port + port_offset,
+                                   false,
                                    host,
                                    "tp_group",
                                    device);
-
-  parallel_args_->process_group_ = process_group_.get();
   parallel_args_->tp_group_ = tp_group_.get();
+  port += dp_size;
+
+  if (dp_size > 1) {
+    port_offset = global_rank % tp_size + 1;
+    dp_local_process_group_ = create_process_group(global_rank,
+                                                   world_size,
+                                                   dp_size,
+                                                   port + port_offset,
+                                                   true,
+                                                   host,
+                                                   "dp_group",
+                                                   device);
+    parallel_args_->dp_local_process_group_ = dp_local_process_group_.get();
+    port += tp_size;
+  }
+
+  if (ep_size > 1) {
+    int moe_tp_size = world_size / ep_size;
+    port_offset = global_rank / moe_tp_size + 1;
+    moe_tp_group_ = create_process_group(global_rank,
+                                         world_size,
+                                         moe_tp_size,
+                                         port + port_offset,
+                                         false,
+                                         host,
+                                         "moe_tp_group",
+                                         device);
+    parallel_args_->moe_tp_group_ = moe_tp_group_.get();
+    port += ep_size;
+    port_offset = global_rank % moe_tp_size + 1;
+    moe_ep_group_ = create_process_group(global_rank,
+                                         world_size,
+                                         ep_size,
+                                         port + port_offset,
+                                         true,
+                                         host,
+                                         "moe_ep_group",
+                                         device);
+    parallel_args_->moe_ep_group_ = moe_ep_group_.get();
+  }
 }
 
 const ParallelArgs* CollectiveCommunicator::parallel_args() {
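A note on the grouping and port logic above, since it is easy to misread: each call reserves a distinct rendezvous port per group via port_offset, and the `trans` flag selects how ranks are partitioned. `get_group_rank` is not part of this diff, so the following self-contained sketch encodes an assumption about its semantics (`group_rank_sketch` is a hypothetical stand-in): trans == false groups consecutive ranks, as for tp and moe_tp; trans == true groups strided ranks, as for dp and moe_ep.

// Hypothetical, self-contained illustration (not repo code) of the grouping
// the trans flag appears to select. Assumption: trans == false builds groups
// of consecutive ranks (tp-style); trans == true builds strided groups
// (dp/ep-style).
#include <cstdio>
#include <utility>
#include <vector>

std::pair<int, std::vector<int>> group_rank_sketch(int world_size,
                                                   int global_rank,
                                                   int rank_size,
                                                   bool trans) {
  std::vector<int> group_ranks;
  int local_rank = 0;
  if (!trans) {
    // Consecutive: {0..n-1}, {n..2n-1}, ... where n = rank_size.
    int group = global_rank / rank_size;
    for (int i = 0; i < rank_size; ++i) {
      group_ranks.push_back(group * rank_size + i);
    }
    local_rank = global_rank % rank_size;
  } else {
    // Strided: {k, k+s, k+2s, ...} where s = world_size / rank_size.
    int stride = world_size / rank_size;
    int offset = global_rank % stride;
    for (int i = 0; i < rank_size; ++i) {
      group_ranks.push_back(offset + i * stride);
    }
    local_rank = global_rank / stride;
  }
  return {local_rank, group_ranks};
}

int main() {
  // world_size = 8, dp_size = 2 => tp_size = 4.
  // Expected: tp groups {0,1,2,3} {4,5,6,7}; dp groups {0,4} {1,5} {2,6} {3,7}.
  for (int r = 0; r < 8; ++r) {
    auto [tp_local, tp_group] = group_rank_sketch(8, r, 4, false);
    auto [dp_local, dp_group] = group_rank_sketch(8, r, 2, true);
    std::printf("rank %d: tp peers %d..%d (local %d), dp peers {%d,%d} (local %d)\n",
                r, tp_group.front(), tp_group.back(), tp_local,
                dp_group[0], dp_group[1], dp_local);
  }
  return 0;
}

Under that reading, with world_size = 8 and dp_size = 2 there are dp_size = 2 tp groups ({0,1,2,3} and {4,5,6,7}) and tp_size = 4 dp groups ({0,4}, {1,5}, {2,6}, {3,7}), which is why the code advances port by dp_size after the tp groups and by tp_size after the dp groups; the moe_tp/moe_ep pair repeats the same pattern with moe_tp_size = world_size / ep_size.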

xllm/core/framework/parallel_state/collective_communicator.h

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ class CollectiveCommunicator {
   std::unique_ptr<ProcessGroup> process_group_;
   std::unique_ptr<ProcessGroup> dp_local_process_group_;
   std::unique_ptr<ProcessGroup> tp_group_;
+  std::unique_ptr<ProcessGroup> moe_tp_group_;
+  std::unique_ptr<ProcessGroup> moe_ep_group_;
 };
 
 }  // namespace xllm

xllm/core/framework/parallel_state/cuda_process_group.cpp

Lines changed: 0 additions & 69 deletions
This file was deleted.

xllm/core/framework/parallel_state/cuda_process_group.h

Lines changed: 35 additions & 17 deletions
@@ -23,30 +23,48 @@ namespace xllm {
 
 class ProcessGroupNccl : public ProcessGroup {
  public:
-  ProcessGroupNccl(int rank,
+  ProcessGroupNccl(int global_rank,
                    int world_size,
                    int rank_size,
                    int port,
+                   bool trans,
                    const std::string& host,
                    const std::string& group_name,
-                   const torch::Device& device);
+                   const torch::Device& device)
+      : ProcessGroup(device) {
+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> pg_options =
+        c10d::ProcessGroupNCCL::Options::create();
+    pg_options->group_name = group_name;
+    int rank = global_rank;
+    if (world_size != rank_size) {
+      auto [local_rank, group_ranks] =
+          get_group_rank(world_size, global_rank, rank_size, trans);
+      pg_options->global_ranks_in_group = group_ranks;
+      rank = local_rank;
+    }
 
-  ~ProcessGroupNccl() override;
+    auto store = create_tcp_store(host, port, rank);
+    pg_ = std::make_unique<c10d::ProcessGroupNCCL>(
+        store, rank, rank_size, pg_options);
+  }
 
-  void allreduce(torch::Tensor& input) override;
-
-  void allgather(torch::Tensor input,
-                 std::vector<torch::Tensor>& outputs) override;
-
- private:
-  // rank of current process
-  int rank_ = 0;
-
-  // number of processes
-  int world_size_ = 0;
-
-  // nccl process group
-  std::unique_ptr<c10d::ProcessGroupNCCL> nccl_pg_;
+  ~ProcessGroupNccl() override {
+    if (pg_) {
+      pg_->shutdown();
+    }
+  }
 };
 
+std::unique_ptr<xllm::ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    bool trans,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device) {
+  return std::make_unique<ProcessGroupNccl>(
+      rank, world_size, rank_size, port, trans, host, group_name, device);
+}
 }  // namespace xllm
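The constructor now builds the c10d::ProcessGroupNCCL inline, relying on two helpers outside this diff (`create_tcp_store` and `get_group_rank`, presumably declared alongside the ProcessGroup base). For readers unfamiliar with the underlying libtorch API, here is a minimal sketch of the same rendezvous pattern; it is not the repo's code, `make_nccl_group` is a hypothetical name, and the exact Options fields depend on the libtorch version in use.

#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

c10::intrusive_ptr<c10d::ProcessGroupNCCL> make_nccl_group(
    const std::string& host, int port, int rank, int size,
    const std::string& name) {
  // Rendezvous: rank 0 serves the TCPStore, the other ranks connect to
  // host:port, which is what the host/port arguments above imply.
  c10d::TCPStoreOptions store_opts;
  store_opts.port = static_cast<uint16_t>(port);
  store_opts.isServer = (rank == 0);
  store_opts.numWorkers = size;
  auto store = c10::make_intrusive<c10d::TCPStore>(host, store_opts);

  // Same Options usage as the constructor above; for a sub-group one would
  // also fill pg_options->global_ranks_in_group with the member ranks.
  auto pg_options = c10d::ProcessGroupNCCL::Options::create();
  pg_options->group_name = name;
  return c10::make_intrusive<c10d::ProcessGroupNCCL>(store, rank, size, pg_options);
}

Note the shape of the refactor: the per-backend allreduce/allgather overrides and the rank_/world_size_/nccl_pg_ members are gone, the destructor only shuts the group down, and the collectives are presumably provided by the ProcessGroup base over the shared pg_ member introduced by this commit.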

xllm/core/framework/parallel_state/mlu_process_group.cpp

Lines changed: 0 additions & 67 deletions
This file was deleted.
