Commit c42f0a7 (parent: 94f0691)

[XPU] support moe_gate_dispatch_partial_nosoftmaxtopk, expand_modality_expert_id and build_src_rank_and_local_expert_id (#73234)

Commit message:

* [XPU] support moe_gate_dispatch_partial_nosoftmaxtopk, expand_modality_expert_id and build_src_rank_and_local_expert_id
* fix (repeated across nine follow-up fixup commits)

12 files changed: +729 −46 lines

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 12 additions & 0 deletions

@@ -1851,6 +1851,18 @@ XPUOpMap& get_kl3_ops() {
                    phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
                    phi::DataType::INT32})},
+    {"expand_modality_expert_id",
+     XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
+    {"build_src_rank_and_local_expert_id",
+     XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
+    {"moe_gate_dispatch_partial_nosoftmaxtopk",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16})},
+    {"moe_gate_dispatch_partial_nosoftmaxtopk_grad",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::BFLOAT16})},
     {"blha_get_max_len",
      XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
     {"full_with_tensor",
(new file: XPU kernel for expand_modality_expert_id)

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ExpandModalityExpertIDKernel(const Context& dev_ctx,
+                                  const DenseTensor& expert_id,
+                                  int64_t num_expert_per_modality,
+                                  int64_t group_size,
+                                  int64_t modality_offset,
+                                  bool is_group_expert,
+                                  DenseTensor* expert_id_out) {
+  dev_ctx.template Alloc<T>(expert_id_out);
+  auto expert_id_shape = expert_id.dims();
+  int64_t seqlen = expert_id_shape[0];
+  int64_t k = expert_id_shape[1];
+
+  int r = xpu::expand_modality_expert_id(dev_ctx.x_context(),
+                                         expert_id.data<T>(),
+                                         expert_id_out->data<T>(),
+                                         seqlen,
+                                         k,
+                                         num_expert_per_modality,
+                                         group_size,
+                                         modality_offset,
+                                         is_group_expert);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "expand_modality_expert_id");
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(expand_modality_expert_id,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ExpandModalityExpertIDKernel,
+                   int,
+                   int64_t) {}
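
The wrapper is thin: it allocates the output, reads the [seqlen, k] expert-id shape, and forwards everything to the XDNN primitive xpu::expand_modality_expert_id. The primitive's semantics are not visible in this diff; purely as a reading aid, the CPU sketch below shows one plausible behavior (expanding modality-local expert ids to global ids, with optional grouped experts). It is an assumption for illustration, not the verified implementation.

    // SPECULATIVE reference for expand_modality_expert_id. Assumes: with
    // is_group_expert, top-k slot j selects within group j of size
    // group_size; the result is then shifted into this modality's slice of
    // the global expert table. None of this is confirmed by the diff; it
    // only gives the shape of computation such a primitive typically does.
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> expand_modality_expert_id_ref(
        const std::vector<int64_t>& expert_id,  // flattened [seqlen, k]
        int64_t seqlen, int64_t k,
        int64_t num_expert_per_modality,
        int64_t group_size,
        int64_t modality_offset,
        bool is_group_expert) {
      std::vector<int64_t> out(expert_id.size());
      for (int64_t i = 0; i < seqlen; ++i) {
        for (int64_t j = 0; j < k; ++j) {
          int64_t id = expert_id[i * k + j];
          if (is_group_expert) id += j * group_size;  // assumed group base
          out[i * k + j] = id + modality_offset * num_expert_per_modality;
        }
      }
      return out;
    }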
(new file: XPU kernel for build_src_rank_and_local_expert_id)

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BuildSrcRankAndLocalExpertIdKernel(
+    const Context& dev_ctx,
+    const DenseTensor& expert_num_global_tensor,
+    const std::vector<int64_t>& expert_num_global,
+    int64_t num_local_experts,
+    DenseTensor* src_rank,
+    DenseTensor* local_expert_id) {
+  int64_t token_num =
+      std::accumulate(expert_num_global.begin(), expert_num_global.end(), 0);
+
+  const int64_t* expert_num_global_tensor_data =
+      expert_num_global_tensor.data<int64_t>();
+
+  // Hard coded as ernie-core did.
+  int* src_rank_data = dev_ctx.template Alloc<int>(src_rank);
+  int* local_expert_id_data = dev_ctx.template Alloc<int>(local_expert_id);
+
+  int r = xpu::build_srcrank_and_local_expert_id(
+      dev_ctx.x_context(),
+      src_rank_data,
+      local_expert_id_data,
+      expert_num_global_tensor_data,
+      expert_num_global,
+      token_num,
+      static_cast<int64_t>(expert_num_global.size()),
+      num_local_experts);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "build_srcrank_and_local_expert_id");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(build_src_rank_and_local_expert_id,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::BuildSrcRankAndLocalExpertIdKernel,
+                   int,
+                   int64_t) {}
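
Here the wrapper sums expert_num_global on the host to get the total token count, then lets the XDNN primitive fill one src_rank and one local_expert_id entry per token. The sketch below gives a CPU reference for the semantics the names suggest, assuming expert_num_global[e] is the token count routed to global expert e and experts are laid out rank-major (num_local_experts per rank); this is an assumption for illustration.

    // Speculative reference: expands per-expert token counts into per-token
    // (source rank, local expert id) pairs, assuming global expert e lives
    // on rank e / num_local_experts as its (e % num_local_experts)-th local
    // expert. Illustrative only; not the verified XDNN behavior.
    #include <cstdint>
    #include <vector>

    void build_src_rank_and_local_expert_id_ref(
        const std::vector<int64_t>& expert_num_global,
        int64_t num_local_experts,
        std::vector<int>* src_rank,
        std::vector<int>* local_expert_id) {
      src_rank->clear();
      local_expert_id->clear();
      int64_t num_experts = static_cast<int64_t>(expert_num_global.size());
      for (int64_t e = 0; e < num_experts; ++e) {
        for (int64_t t = 0; t < expert_num_global[e]; ++t) {
          src_rank->push_back(static_cast<int>(e / num_local_experts));
          local_expert_id->push_back(static_cast<int>(e % num_local_experts));
        }
      }
    }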

paddle/phi/kernels/xpu/moe_gate_dispatch_grad_kernel.cc

Lines changed: 3 additions & 9 deletions

@@ -1,4 +1,3 @@
-// NOLINT
 // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -73,13 +72,8 @@ void moe_dispatch_grad(
   int64_t num_rows = scatter_index.dims()[1];

   const std::vector<int32_t> axis = {1, 0};
-  DenseTensor t_scatter_index_tmp;
-  phi::Transpose<int, Context>(
-      dev_ctx, scatter_index, axis, &t_scatter_index_tmp);
-  DenseTensor t_scatter_index_;
-  phi::ContiguousKernel<int, Context>(
-      dev_ctx, t_scatter_index_tmp, &t_scatter_index_);
-  const DenseTensor t_scatter_index = t_scatter_index_;
+  DenseTensor t_scatter_index;
+  phi::Transpose<int, Context>(dev_ctx, scatter_index, axis, &t_scatter_index);

   // output
   DenseTensor x_grad_tmp =
@@ -92,7 +86,7 @@ void moe_dispatch_grad(
   auto combine_weights_data =
       reinterpret_cast<const float*>(combine_weights.data<float>());
   auto t_scatter_index_data =
-      reinterpret_cast<const int*>(t_scatter_index_tmp.data<int>());
+      reinterpret_cast<const int*>(t_scatter_index.data<int>());
   auto combine_weights_grad_data =
       reinterpret_cast<const float*>(combine_weights_grad.data<float>());
   auto expert_id_data = reinterpret_cast<const int*>(expert_id.data<int>());
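
This hunk does two things: it drops the separate ContiguousKernel pass (the simplification appears to assume that phi::Transpose already materializes a contiguous result), and it fixes the later pointer read, which previously dereferenced the intermediate t_scatter_index_tmp, to use the single tensor that survives the refactor. For reference, the layout change the kernel relies on is a plain [k, num_rows] to [num_rows, k] transpose; a minimal standalone sketch:

    // Standalone sketch of the transpose applied to scatter_index
    // (axis = {1, 0}). Writing row-major into `out` is what makes the
    // result contiguous, the property the removed ContiguousKernel pass
    // used to guarantee explicitly. Illustrative only; not commit code.
    #include <cstdint>
    #include <vector>

    std::vector<int> transpose_scatter_index(const std::vector<int>& in,
                                             int64_t k, int64_t num_rows) {
      std::vector<int> out(in.size());
      for (int64_t i = 0; i < k; ++i)
        for (int64_t j = 0; j < num_rows; ++j)
          out[j * k + i] = in[i * num_rows + j];
      return out;
    }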

paddle/phi/kernels/xpu/moe_gate_dispatch_kernel.cc

Lines changed: 21 additions & 20 deletions

@@ -1,4 +1,3 @@
-// NOLINT
 // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,30 +33,32 @@ void moe_dispatch_fwd(const Context &dev_ctx,
                       DenseTensor *expert_offset,
                       DenseTensor *expert_id,
                       bool use_pad) {
-  if (!(x.dtype() == paddle::DataType::FLOAT32 ||
-        x.dtype() == paddle::DataType::FLOAT16 ||
-        x.dtype() == paddle::DataType::BFLOAT16)) {
-    PD_THROW(
-        "Unsupported dtype for x, "
-        "currently float32, float16 and bfloat16 are supported.");
-  }
-
-  if (gate_logits.dtype() != paddle::DataType::FLOAT32) {
-    PD_THROW(
-        "Unsupported dtype for gate_logits, "
-        "currently only float32 is supported.");
-  }
+  PADDLE_ENFORCE_EQ(gate_logits.dtype(),
+                    paddle::DataType::FLOAT32,
+                    ::common::errors::InvalidArgument(
+                        "Unsupported dtype for gate_logits, "
+                        "currently only float32 is supported."));

   int64_t s = x.dims()[0];
   int64_t d = x.dims()[1];
   int64_t e = gate_logits.dims()[1];

-  if (k <= 0) PD_THROW("the k of topk must more than 0.");
-  if (capacity <= 0) PD_THROW("the capacity of each expert must more than 0.");
-  if (e < k) PD_THROW("the amount of experts must greater than k.");
-  if (k > 512) PD_THROW("currently, the k of topk must lesser than 512.");
-  if (e > 512 * 64 * 12)
-    PD_THROW("currently, he amount of experts must lesser than 393216.");
+  PADDLE_ENFORCE_GT(
+      k,
+      0,
+      ::common::errors::InvalidArgument("the k of topk must more than 0."));
+  PADDLE_ENFORCE_GT(capacity,
+                    0,
+                    ::common::errors::InvalidArgument(
+                        "the capacity of each expert must more than 0."));
+  PADDLE_ENFORCE_GE(e,
+                    k,
+                    ::common::errors::InvalidArgument(
+                        "the amount of experts must greater than k."));
+  PADDLE_ENFORCE_EQ(
+      corr_bias.is_initialized(),
+      false,
+      ::common::errors::InvalidArgument("corr_bias is not supported yet"));

   using XPUType = typename XPUTypeTrait<T>::Type;
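
The forward kernel's argument checks move from PD_THROW to the PADDLE_ENFORCE_* comparison macros, which raise a typed ::common::errors::InvalidArgument instead of an untyped exception. The hunk also adds a guard rejecting a supplied corr_bias, and drops both the explicit x dtype check and the old upper bounds on k (512) and on the expert count (393216). A minimal standalone illustration of the macro family as used here; check_dispatch_args is a hypothetical helper, not code from the commit:

    // Hypothetical helper showing the PADDLE_ENFORCE_* pattern this hunk
    // adopts. The macros and ::common::errors::InvalidArgument come in via
    // Paddle's enforce headers (enforce_xpu.h in these kernels).
    #include "paddle/phi/backends/xpu/enforce_xpu.h"

    void check_dispatch_args(int64_t k, int64_t capacity, int64_t num_experts) {
      PADDLE_ENFORCE_GT(
          k, 0, ::common::errors::InvalidArgument("k of topk must be > 0."));
      PADDLE_ENFORCE_GT(capacity,
                        0,
                        ::common::errors::InvalidArgument(
                            "capacity of each expert must be > 0."));
      PADDLE_ENFORCE_GE(num_experts,
                        k,
                        ::common::errors::InvalidArgument(
                            "number of experts must be >= k."));
    }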

(new file: XPU kernel for moe_gate_dispatch_partial_nosoftmaxtopk_grad)

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/contiguous_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& combine_weights_out,
+    const DenseTensor& scatter_index,
+    const DenseTensor& scatter_index_rev,
+    const DenseTensor& expert_offset,
+    const DenseTensor& expert_offset_local,
+    const DenseTensor& y_grad,
+    const DenseTensor& combine_weights_out_grad,
+    int64_t k,
+    int64_t capacity,
+    bool use_pad,
+    int64_t expert_start_index,
+    int64_t expert_end_index,
+    DenseTensor* x_grad,
+    DenseTensor* combine_weights_grad) {
+  dev_ctx.template Alloc<T>(x_grad);
+  dev_ctx.template Alloc<float>(combine_weights_grad);
+  phi::Full<float, Context>(
+      dev_ctx,
+      phi::IntArray(common::vectorize(combine_weights_grad->dims())),
+      0,
+      combine_weights_grad);
+  DenseTensor t_scatter_index;
+  phi::Transpose<int, Context>(
+      dev_ctx, scatter_index, {1, 0}, &t_scatter_index);
+
+  int64_t num_rows = combine_weights_out.dims()[0];
+  int64_t hidden_size = y_grad.dims()[1];
+  int64_t num_experts = expert_offset.dims()[0];
+  int64_t num_active = y_grad.dims()[0];
+
+  using XPUDataType = typename XPUTypeTrait<T>::Type;
+  int r = xpu::moe_gate_dispatch_partial_nosoftmaxtopk_grad(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUDataType*>(y_grad.data<T>()),
+      combine_weights_out.data<float>(),
+      t_scatter_index.data<int>(),
+      combine_weights_out_grad.data<float>(),
+      combine_weights_grad->data<float>(),
+      reinterpret_cast<XPUDataType*>(x_grad->data<T>()),
+      num_rows,
+      k,
+      hidden_size,
+      num_experts,
+      num_active);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r,
+                              "moe_gate_dispatch_partial_nosoftmaxtopk_grad");
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(moe_gate_dispatch_partial_nosoftmaxtopk_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MoeGateDispatchPartialNoSoftMaxTopkGradKernel,
+                   float,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
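
As with the other new files, the heavy lifting happens in the XDNN primitive; the phi wrapper zero-fills combine_weights_grad, transposes scatter_index into [num_rows, k] layout, and derives the problem sizes from the tensor shapes. The primitive's semantics are not part of this diff; purely as a reading aid, here is a speculative CPU sketch of the x_grad accumulation one would expect from a top-k combine backward. It is an assumption for illustration, not the verified XDNN behavior.

    // SPECULATIVE sketch of the x_grad gather in a dispatch backward:
    // each token i accumulates, over its k expert slots, the dispatched
    // row's upstream gradient scaled by that slot's combine weight. A zero
    // weight or negative index marks a slot that was not dispatched (for
    // example, an expert outside [expert_start_index, expert_end_index))
    // and contributes nothing.
    #include <cstdint>
    #include <vector>

    void x_grad_ref(const std::vector<float>& y_grad,       // [num_active, h]
                    const std::vector<float>& combine_w,    // [num_rows, k]
                    const std::vector<int>& scatter_index,  // [num_rows, k]
                    int64_t num_rows, int64_t k, int64_t h,
                    std::vector<float>* x_grad) {           // [num_rows, h]
      x_grad->assign(num_rows * h, 0.f);
      for (int64_t i = 0; i < num_rows; ++i) {
        for (int64_t j = 0; j < k; ++j) {
          float w = combine_w[i * k + j];
          int row = scatter_index[i * k + j];
          if (w == 0.f || row < 0) continue;  // slot not dispatched
          for (int64_t c = 0; c < h; ++c)
            (*x_grad)[i * h + c] += w * y_grad[row * h + c];
        }
      }
    }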
