Commit 34a136e

[Comm] Fix_NPU_Comm (#71723) (#71742)
1 parent 5aab1be commit 34a136e

File tree

5 files changed, +151 −114 lines

paddle/common/macros.h
paddle/phi/api/generator/api_gen.py
paddle/phi/api/generator/dist_api_gen.py
paddle/phi/api/generator/dist_bw_api_gen.py
paddle/phi/kernels/custom/c_softmax_with_entropy_kernel.cc


paddle/common/macros.h

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,8 @@ limitations under the License. */
 #define COMM_CONTEXT phi::distributed::NCCLCommContext
 #elif (defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL))
 #define COMM_CONTEXT phi::distributed::BKCLCommContext
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+#define COMM_CONTEXT phi::distributed::XCCLCommContext
 #endif
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -38,6 +40,9 @@ limitations under the License. */
 #elif defined(PADDLE_WITH_XPU_BKCL)
 #define CREATE_COMM_CONTEXT \
   phi::distributed::CommContextManager::CreateBKCLCommContext
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+#define CREATE_COMM_CONTEXT \
+  phi::distributed::CommContextManager::CreateXCCLCommContext
 #endif
 
 namespace common {

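The new branches route custom devices (e.g. NPUs) to the XCCL communication context. As a minimal sketch, not part of the commit, of how the two macros are meant to be consumed on a PADDLE_WITH_CUSTOM_DEVICE build: the helper name InitXcclRing is hypothetical and extra headers may be needed for CreateOrGetGlobalTCPStore, while the call pattern copies the generator template in dist_api_gen.py below.

// Sketch only: under PADDLE_WITH_CUSTOM_DEVICE the macros expand to
//   COMM_CONTEXT        -> phi::distributed::XCCLCommContext
//   CREATE_COMM_CONTEXT -> phi::distributed::CommContextManager::CreateXCCLCommContext
#include "paddle/common/macros.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"

void InitXcclRing(int ring_id, int rank, int nranks) {  // hypothetical helper
  const auto& comm_context_manager =
      phi::distributed::CommContextManager::GetInstance();
  if (nranks > 1 && !comm_context_manager.Has(std::to_string(ring_id))) {
    auto store = phi::distributed::CreateOrGetGlobalTCPStore();
    // XCCL creation takes an extra place argument compared with NCCL/BKCL.
    CREATE_COMM_CONTEXT(store,
                        std::to_string(ring_id),
                        phi::distributed::GetDefaultPlace(),
                        rank,
                        nranks);
  }
}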
paddle/phi/api/generator/api_gen.py

Lines changed: 3 additions & 0 deletions
@@ -526,6 +526,9 @@ def source_include(header_file_path):
 #elif (defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL))
 #include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/distributed/bkcl_comm_context.h"
+#elif PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
 #endif
 
 #ifdef PADDLE_WITH_DISTRIBUTE

paddle/phi/api/generator/dist_api_gen.py

Lines changed: 24 additions & 5 deletions
@@ -90,11 +90,17 @@
     auto store = phi::distributed::CreateOrGetGlobalTCPStore();
     CREATE_COMM_CONTEXT(store, std::to_string(ring_id), rank, nranks);
   }}
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+  const auto & comm_context_manager_ = phi::distributed::CommContextManager::GetInstance();
+  if (nranks > 1 && !comm_context_manager_.Has(std::to_string(ring_id))) {{
+    auto store = phi::distributed::CreateOrGetGlobalTCPStore();
+    CREATE_COMM_CONTEXT(store, std::to_string(ring_id), phi::distributed::GetDefaultPlace(), rank, nranks);
+  }}
 #endif
 """
 
 SET_NCCL_COMMCONTEXT = """
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_XPU_BKCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
   const auto & comm_context_manager = phi::distributed::CommContextManager::GetInstance();
   COMM_CONTEXT* comm_context = nullptr;
   if (comm_context_manager.Has(std::to_string(ring_id))) {{
@@ -107,8 +113,19 @@
           "NCCLCommContext is nullptr, collective op should "
           "has ring_id(%d) attr.",
           std::to_string(ring_id)));
-  if (!comm_context->GetDevContext() || !comm_context->GetDevContext()->GetCommContext())
-  {{
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_XPU_BKCL)
+  if (!comm_context->GetDevContext() || !comm_context->GetDevContext()->GetCommContext())
+  {{
+    auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
+    if (FLAGS_low_precision_op_list) {{
+      phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{}", kernel_data_type);
+    }}
+    Backend act_kernel_backend = kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
+    auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
+    dev_context->SetCommContext(comm_context);
+  }}
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
   auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
     "{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
   if (FLAGS_low_precision_op_list) {{
@@ -117,7 +134,7 @@
   Backend act_kernel_backend = kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
   auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
   dev_context->SetCommContext(comm_context);
-  }}
+#endif
 }}
 #endif
 """
@@ -1384,7 +1401,9 @@ def generate_nccl_commcontext_init_code(self) -> str:
         return NCCL_COMMCONTEXT_INIT.format(self.kernel['func'][0])
 
     def generate_set_nccl_commcontext_code(self) -> str:
-        return SET_NCCL_COMMCONTEXT.format(self.kernel['func'][0], self.api)
+        return SET_NCCL_COMMCONTEXT.format(
+            self.kernel['func'][0], self.api, self.kernel['func'][0], self.api
+        )
 
     def generate_reshard_input_code(self) -> str:
         input_reshard_code = ""

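The SET_NCCL_COMMCONTEXT template now carries the two name placeholders in both preprocessor branches, which is why generate_set_nccl_commcontext_code passes self.kernel['func'][0] and self.api twice. As a rough sketch, assuming a made-up kernel name "my_collective" and API name "my_api", the custom-device branch of the generated source would read roughly as follows once Python's str.format has filled the "{}" placeholders and collapsed the doubled "{{"/"}}" back to single braces:

// Fragment (not a complete function): PADDLE_WITH_CUSTOM_DEVICE branch of the
// generated code; "my_collective" and "my_api" are illustrative names only.
auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
    "my_collective", {kernel_backend, kernel_layout, kernel_data_type}, true);
if (FLAGS_low_precision_op_list) {
  phi::KernelFactory::Instance().AddToLowPrecisionKernelList("my_api",
                                                             kernel_data_type);
}
Backend act_kernel_backend =
    kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
dev_context->SetCommContext(comm_context);  // comm_context is an XCCLCommContext*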
paddle/phi/api/generator/dist_bw_api_gen.py

Lines changed: 3 additions & 0 deletions
@@ -523,6 +523,9 @@ def source_include(header_file_path, fw_header_file_path):
 #elif defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/distributed/bkcl_comm_context.h"
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
 #endif
 
 #ifdef PADDLE_WITH_DISTRIBUTE

paddle/phi/kernels/custom/c_softmax_with_entropy_kernel.cc

Lines changed: 116 additions & 109 deletions
@@ -30,119 +30,126 @@ void CSoftmaxWithEntropyKernel(const Context& dev_ctx,
                                const DenseTensor& logits_in,
                                const DenseTensor& label_in,
                                int64_t ignore_index,
-                               int ring_id,
                                int rank,
                                int nranks,
                                DenseTensor* softmax,
                                DenseTensor* loss) {
-  const int rid = ring_id;
-  auto map = distributed::ProcessGroupMapFromGid::getInstance();
-  if (map->has(rid)) {
-    const phi::DenseTensor* logits = &logits_in;
-    const phi::DenseTensor* labels = &label_in;
-    auto softmax_dims = softmax->dims();
-    auto loss_dims = loss->dims();
-
-    const int rid = ring_id;
-
-    distributed::ProcessGroup* pg = map->get(rid);
-    distributed::AllreduceOptions opts;
-
-    // allocate memory on device.
-    const auto& logits_dims = logits->dims();
-
-    const int axis = logits_dims.size() - 1;
-    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
-    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
-
-    auto logits_2d = std::make_shared<phi::DenseTensor>();
-    auto labels_1d = std::make_shared<phi::DenseTensor>();
-    logits_2d->ShareDataWith(*logits).Resize({N, D});
-    labels_1d->ShareDataWith(*labels).Resize({N});
-    paddle::Tensor logits_2d_tensor(logits_2d), labels_1d_tensor(labels_1d);
-
-    // step 1, obtain logit_max
-    auto logits_2d_max_tensor = logits_2d_tensor.max({1}, true);
-    std::vector<phi::DenseTensor> in_out;
-    in_out.push_back(*reinterpret_cast<phi::DenseTensor*>(
-        logits_2d_max_tensor.impl().get()));
-    opts.reduce_op = distributed::ReduceOp::MAX;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
-
-    // step 2, obtain logit - logit_max
-    auto logits_2d_sub_max = paddle::experimental::clip(
-        logits_2d_tensor - logits_2d_max_tensor, -64., 0.);
-
-    // step 3, obtain predict target
-    const int start_index = rank * D;
-    auto start_index_tensor =
-        paddle::experimental::full_like(labels_1d_tensor,
-                                        start_index,
-                                        labels_1d_tensor.dtype(),
-                                        labels_1d_tensor.place());
-    auto end_index_tensor =
-        paddle::experimental::full_like(labels_1d_tensor,
-                                        start_index + D,
-                                        labels_1d_tensor.dtype(),
-                                        labels_1d_tensor.place());
-    auto labels_1d_mask = paddle::experimental::logical_and(
-        labels_1d_tensor.greater_equal(start_index_tensor),
-        labels_1d_tensor.less_than(end_index_tensor));
-    auto real_label_tensor = (labels_1d_tensor - start_index_tensor)
-                                 .multiply(paddle::experimental::cast(
-                                     labels_1d_mask, labels_1d_tensor.dtype()));
-
-    auto predicted_logits_tensor =
-        logits_2d_sub_max
-            .multiply(paddle::experimental::cast(
-                paddle::experimental::one_hot(real_label_tensor, D),
-                logits_2d_sub_max.dtype()))
-            .sum({1}, logits_2d_sub_max.dtype(), false)
-            .multiply(paddle::experimental::cast(labels_1d_mask,
-                                                 logits_2d_sub_max.dtype()));
-
-    in_out.clear();
-    in_out.push_back(*reinterpret_cast<phi::DenseTensor*>(
-        predicted_logits_tensor.impl().get()));
-    opts.reduce_op = distributed::ReduceOp::SUM;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
-
-    // step 4, obtain exp(logit)
-    auto softmax_2d_tensor = logits_2d_sub_max.exp();
-
-    // step 5, obtain sum_exp_logits
-    auto sum_exp_logits_tensor =
-        softmax_2d_tensor.sum({1}, softmax_2d_tensor.dtype(), false);
-
-    in_out.clear();
-    in_out.push_back(*reinterpret_cast<phi::DenseTensor*>(
-        sum_exp_logits_tensor.impl().get()));
-    opts.reduce_op = distributed::ReduceOp::SUM;
-    pg->AllReduce(in_out, in_out, opts)->Synchronize();
-
-    auto softmax_out = softmax_2d_tensor.divide(
-        paddle::experimental::reshape(sum_exp_logits_tensor, {N, 1}));
-    auto labels_1d_not_equal_ignore = labels_1d_tensor.not_equal(
-        paddle::experimental::full_like(labels_1d_tensor,
-                                        ignore_index,
-                                        labels_1d_tensor.dtype(),
-                                        labels_1d_tensor.place()));
-    auto loss_out =
-        (sum_exp_logits_tensor.log() - predicted_logits_tensor)
-            .multiply(paddle::experimental::cast(
-                labels_1d_not_equal_ignore, sum_exp_logits_tensor.dtype()));
-    softmax
-        ->ShareDataWith(
-            *reinterpret_cast<phi::DenseTensor*>(softmax_out.impl().get()))
-        .Resize(softmax_dims);
-    loss->ShareDataWith(
-            *reinterpret_cast<phi::DenseTensor*>(loss_out.impl().get()))
-        .Resize(loss_dims);
-  } else {
-    PADDLE_THROW(
-        common::errors::Unavailable("CustomDevice c_softmax_with_cross_entropy "
-                                    "only support ProcessGroup"));
-  }
+  auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
+      dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(comm,
+                    nullptr,
+                    common::errors::Unavailable(
+                        "XCCLCommContext is nullptr, collective op should "
+                        "has ring_id attr."));
+
+  const phi::DenseTensor* logits = &logits_in;
+  const phi::DenseTensor* labels = &label_in;
+  auto softmax_dims = softmax->dims();
+  auto loss_dims = loss->dims();
+
+  const int axis = logits->dims().size() - 1;
+  const int N = phi::funcs::SizeToAxis(axis, logits->dims());
+  const int D = phi::funcs::SizeFromAxis(axis, logits->dims());
+
+  auto logits_2d = std::make_shared<phi::DenseTensor>();
+  auto labels_1d = std::make_shared<phi::DenseTensor>();
+  logits_2d->ShareDataWith(*logits).Resize({N, D});
+  labels_1d->ShareDataWith(*labels).Resize({N});
+  paddle::Tensor logits_2d_tensor(logits_2d), labels_1d_tensor(labels_1d);
+
+  // step 1, obtain logit_max
+  auto logits_2d_max_tensor = logits_2d_tensor.max({1}, true);
+  auto logits_2d_max =
+      reinterpret_cast<phi::DenseTensor*>(logits_2d_max_tensor.impl().get());
+  auto& stream = *dev_ctx.GetStream();
+  phi::DeviceManager::CCLAllReduce(dev_ctx.GetPlace().GetDeviceType(),
+                                   logits_2d_max->data<float>(),
+                                   logits_2d_max->data<float>(),
+                                   logits_2d_max->numel(),
+                                   logits_2d_max->dtype(),
+                                   phi::ccl::CCLReduceOp::MAX,
+                                   comm->GetXcclComm(),
+                                   stream);
+
+  // step 2, obtain logit - logit_max
+  auto logits_2d_sub_max = paddle::experimental::clip(
+      logits_2d_tensor - logits_2d_max_tensor, -64., 0.);
+
+  // step 3, obtain predict target
+  const int start_index = rank * D;
+  auto start_index_tensor =
+      paddle::experimental::full_like(labels_1d_tensor,
+                                      start_index,
+                                      labels_1d_tensor.dtype(),
+                                      labels_1d_tensor.place());
+  auto end_index_tensor =
+      paddle::experimental::full_like(labels_1d_tensor,
+                                      start_index + D,
+                                      labels_1d_tensor.dtype(),
+                                      labels_1d_tensor.place());
+  auto labels_1d_mask = paddle::experimental::logical_and(
+      labels_1d_tensor.greater_equal(start_index_tensor),
+      labels_1d_tensor.less_than(end_index_tensor));
+  auto real_label_tensor = (labels_1d_tensor - start_index_tensor)
+                               .multiply(paddle::experimental::cast(
+                                   labels_1d_mask, labels_1d_tensor.dtype()));
+
+  auto predicted_logits_tensor =
+      logits_2d_sub_max
+          .multiply(paddle::experimental::cast(
+              paddle::experimental::one_hot(real_label_tensor, D),
+              logits_2d_sub_max.dtype()))
+          .sum({1}, logits_2d_sub_max.dtype(), false)
+          .multiply(paddle::experimental::cast(labels_1d_mask,
+                                               logits_2d_sub_max.dtype()));
+
+  auto predicted_logits =
+      reinterpret_cast<phi::DenseTensor*>(predicted_logits_tensor.impl().get());
+  phi::DeviceManager::CCLAllReduce(dev_ctx.GetPlace().GetDeviceType(),
+                                   predicted_logits->data<float>(),
+                                   predicted_logits->data<float>(),
+                                   predicted_logits->numel(),
+                                   predicted_logits->dtype(),
+                                   phi::ccl::CCLReduceOp::SUM,
+                                   comm->GetXcclComm(),
+                                   stream);
+
+  // step 4, obtain exp(logit)
+  auto softmax_2d_tensor = logits_2d_sub_max.exp();
+
+  // step 5, obtain sum_exp_logits
+  auto sum_exp_logits_tensor =
+      softmax_2d_tensor.sum({1}, softmax_2d_tensor.dtype(), false);
+
+  auto sum_exp_logits =
+      reinterpret_cast<phi::DenseTensor*>(sum_exp_logits_tensor.impl().get());
+  phi::DeviceManager::CCLAllReduce(dev_ctx.GetPlace().GetDeviceType(),
+                                   sum_exp_logits->data<float>(),
+                                   sum_exp_logits->data<float>(),
+                                   sum_exp_logits->numel(),
+                                   sum_exp_logits->dtype(),
+                                   phi::ccl::CCLReduceOp::SUM,
+                                   comm->GetXcclComm(),
+                                   stream);
+
+  auto softmax_out = softmax_2d_tensor.divide(
+      paddle::experimental::reshape(sum_exp_logits_tensor, {N, 1}));
+  auto labels_1d_not_equal_ignore = labels_1d_tensor.not_equal(
+      paddle::experimental::full_like(labels_1d_tensor,
+                                      ignore_index,
+                                      labels_1d_tensor.dtype(),
+                                      labels_1d_tensor.place()));
+  auto loss_out =
+      (sum_exp_logits_tensor.log() - predicted_logits_tensor)
+          .multiply(paddle::experimental::cast(labels_1d_not_equal_ignore,
+                                               sum_exp_logits_tensor.dtype()));
+  softmax
+      ->ShareDataWith(
+          *reinterpret_cast<phi::DenseTensor*>(softmax_out.impl().get()))
+      .Resize(softmax_dims);
+  loss->ShareDataWith(
+          *reinterpret_cast<phi::DenseTensor*>(loss_out.impl().get()))
+      .Resize(loss_dims);
 }
 } // namespace phi

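The rewrite keeps the same class-parallel softmax-with-cross-entropy computation as the deleted ProcessGroup path; it only swaps the reduction primitive to phi::DeviceManager::CCLAllReduce on the XCCL communicator taken from the device context, and drops the now-unused ring_id parameter. As a summary of the five commented steps (this is a reading of the code above, not text from the commit): with the class dimension sharded so that each of the nranks ranks holds D classes of row x_i, and y_i the label of sample i,

\begin{aligned}
m_i &= \max_j x_{ij} && \text{(local max, then all-reduce MAX)} \\
\tilde{x}_{ij} &= \operatorname{clip}(x_{ij} - m_i,\, -64,\, 0) \\
p_i &= \tilde{x}_{i,y_i} && \text{(nonzero only on the rank owning } y_i\text{, then all-reduce SUM)} \\
s_i &= \textstyle\sum_j e^{\tilde{x}_{ij}} && \text{(local sum, then all-reduce SUM)} \\
\operatorname{softmax}_{ij} &= e^{\tilde{x}_{ij}} / s_i, \qquad
\operatorname{loss}_i = \bigl(\log s_i - p_i\bigr)\,[\,y_i \neq \text{ignore\_index}\,]
\end{aligned}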