
Commit cdcd76a

From00 and lshpku authored

【cherry-pick】Add moe_combine_no_weight OP (#73592) (#73607)

* Add moe_combine_no_weight OP (#73592)
* Add moe_combine_no_weight op (#73531)
* Add moe_combine_no_weight op
* Remove seqlen and k from parameters
* Set no_need_buffer for x
* Add moe_combine_no_weight Op
* Update

---------

Co-authored-by: Shuhao Liang <[email protected]>

* Empty-Commit

---------

Co-authored-by: Shuhao Liang <[email protected]>

1 parent 0ffabff commit cdcd76a

File tree

11 files changed, +527 −0 lines changed


paddle/phi/infermeta/ternary.cc

Lines changed: 39 additions & 0 deletions
@@ -1630,6 +1630,45 @@ void MoeCombineInferMeta(const MetaTensor& x,
   y->set_dtype(x.dtype());
 }
 
+void MoeCombineNoWeightInferMeta(const MetaTensor& x,
+                                 const MetaTensor& combine_weights,
+                                 const MetaTensor& scatter_index,
+                                 float epsilon,
+                                 MetaTensor* y) {
+  auto x_dim = x.dims();
+  auto scatter_index_dim = scatter_index.dims();
+  PADDLE_ENFORCE_EQ(x_dim.size(),
+                    2,
+                    common::errors::InvalidArgument(
+                        "The dimensions of Input(x) must be 2, but "
+                        "received dimensions of Input(x) is [%d]",
+                        x_dim.size()));
+  PADDLE_ENFORCE_EQ(scatter_index_dim.size(),
+                    2,
+                    common::errors::InvalidArgument(
+                        "The dimensions of Input(scatter_index) must be 2, but "
+                        "received dimensions of Input(scatter_index) is [%d]",
+                        scatter_index_dim.size()));
+  PADDLE_ENFORCE_EQ(scatter_index.dtype(),
+                    phi::DataType::INT32,
+                    common::errors::InvalidArgument(
+                        "The input scatter_index type should be int32, "
+                        "but received scatter_index type = %s",
+                        scatter_index.dtype()));
+  int64_t seqlen = scatter_index_dim[0];
+  int64_t k = scatter_index_dim[1];
+  int64_t hidden_size = x_dim[1];
+  PADDLE_ENFORCE_EQ(x_dim[0],
+                    seqlen * k,
+                    common::errors::InvalidArgument(
+                        "The upper dim of Input(x) [%d] must equal to "
+                        "the total size of Input(scatter_index) [%d].",
+                        x_dim[0],
+                        seqlen * k));
+  y->set_dims(phi::make_ddim({seqlen, hidden_size}));
+  y->set_dtype(x.dtype());
+}
+
 void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(
     const MetaTensor& x,
     const MetaTensor& combine_weights,
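
For orientation only (not part of the diff), the shape contract this InferMeta encodes can be sketched in plain Python; the sizes below are illustrative, not taken from the commit:

    import numpy as np

    seqlen, k, hidden_size = 4, 2, 8                       # illustrative sizes only
    x = np.zeros((seqlen * k, hidden_size), dtype=np.float32)
    scatter_index = np.zeros((seqlen, k), dtype=np.int32)  # must be 2-D int32
    assert x.ndim == 2 and scatter_index.ndim == 2
    assert x.shape[0] == seqlen * k                        # x_dim[0] == seqlen * k, as enforced above
    y_shape = (seqlen, hidden_size)                        # inferred output shape, with x's dtype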

paddle/phi/infermeta/ternary.h

Lines changed: 6 additions & 0 deletions
@@ -274,6 +274,12 @@ void MoeCombineInferMeta(const MetaTensor& x,
                          const MetaTensor& scatter_index,
                          MetaTensor* y);
 
+void MoeCombineNoWeightInferMeta(const MetaTensor& x,
+                                 const MetaTensor& combine_weights,
+                                 const MetaTensor& scatter_index,
+                                 float epsilon,
+                                 MetaTensor* y);
+
 void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(
     const MetaTensor& x,
     const MetaTensor& combine_weights,

paddle/phi/kernels/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
       "legacy/gpu/expand_modality_expert_id_kernel.cu"
       "legacy/gpu/moe_combine_kernel.cu"
       "legacy/gpu/moe_combine_grad_kernel.cu"
+      "legacy/gpu/moe_combine_no_weight_kernel.cu"
+      "legacy/gpu/moe_combine_no_weight_grad_kernel.cu"
       "legacy/gpu/cal_aux_loss_kernel.cu"
       "legacy/gpu/cal_aux_loss_grad_kernel.cu"
       "legacy/gpu/ext_build_src_rank_and_local_expert_id_kernel.cu"
paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_grad_kernel.cu

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {

template <typename T, typename MTP, int VecSize>
__global__ void combine_no_weight_bwd_kernel(const T* combine_weights,
                                             const int* scatter_index,
                                             const T* grad_y,
                                             T* grad_x,
                                             const int64_t k,
                                             const int64_t seqlen,
                                             const int64_t hidden_size,
                                             const float epsilon) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT grad_y_vec;
  int i = blockIdx.x;   // token index within the sequence
  int ki = blockIdx.y;  // slot index within the top-k

  if (i < seqlen && ki < k) {
    int idx = scatter_index[i * k + ki];  // row of x this slot was gathered from
    if (fabsf(combine_weights[i * k + ki]) <=
        epsilon) {  // no grad for padding tokens
      return;
    }
    // Loop over the hidden dimension in strides of the block
    for (int h_i = threadIdx.x * VecSize; h_i < hidden_size;
         h_i += blockDim.x * VecSize) {
      phi::Load<T, VecSize>(&(grad_y[i * hidden_size + h_i]), &grad_y_vec);
      phi::Store<T, VecSize>(grad_y_vec, &grad_x[idx * hidden_size + h_i]);
    }
  }
}

template <typename T>
void moe_combine_no_weight_bwd(const T* combine_weights,
                               const int* scatter_index,
                               const T* grad_y,
                               T* grad_x,
                               const int64_t k,
                               const int64_t seqlen,
                               const int64_t hidden_size,
                               const float epsilon,
                               cudaStream_t stream) {
  int block_size = 512;
  int grid_size_i = seqlen;
  int grid_size_k = k;
  dim3 blockDim(block_size);
  dim3 gridDim(grid_size_i, grid_size_k);

  constexpr int max_pack_size = 16 / sizeof(T);
  if (hidden_size % max_pack_size == 0) {
    combine_no_weight_bwd_kernel<T, float, max_pack_size>
        <<<gridDim, blockDim, 0, stream>>>(combine_weights,
                                           scatter_index,
                                           grad_y,
                                           grad_x,
                                           k,
                                           seqlen,
                                           hidden_size,
                                           epsilon);
  } else {
    combine_no_weight_bwd_kernel<T, float, 1>
        <<<gridDim, blockDim, 0, stream>>>(combine_weights,
                                           scatter_index,
                                           grad_y,
                                           grad_x,
                                           k,
                                           seqlen,
                                           hidden_size,
                                           epsilon);
  }
}

template <typename T, typename Context>
void MoeCombineNoWeightGradKernel(const Context& dev_ctx,
                                  const DenseTensor& x,
                                  const DenseTensor& combine_weights,
                                  const DenseTensor& scatter_index,
                                  const DenseTensor& grad_y,
                                  const float epsilon,
                                  DenseTensor* grad_x) {
  const auto x_shape = x.dims();
  const int64_t hidden_size = x_shape[1];

  const auto scatter_index_shape = scatter_index.dims();
  const int64_t seqlen = scatter_index_shape[0];
  const int64_t k = scatter_index_shape[1];

  dev_ctx.template Alloc<T>(grad_x);
  phi::Full<T, Context>(
      dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);

  moe_combine_no_weight_bwd<T>(combine_weights.data<T>(),
                               scatter_index.data<int>(),
                               grad_y.data<T>(),
                               grad_x->data<T>(),
                               k,
                               seqlen,
                               hidden_size,
                               epsilon,
                               dev_ctx.stream());
}

}  // namespace phi

PD_REGISTER_KERNEL(moe_combine_no_weight_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MoeCombineNoWeightGradKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
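
As a reading aid only (not part of the commit), the gradient this backward kernel produces can be written as a NumPy reference (the helper name below is made up here): each row of y_grad is copied, unscaled, into the row of x_grad selected by scatter_index, and slots whose combine weight is within epsilon of zero (padding) receive no gradient.

    import numpy as np

    def moe_combine_no_weight_grad_ref(combine_weights, scatter_index, grad_y, epsilon=1e-15):
        # combine_weights, scatter_index: [seqlen, k]; grad_y: [seqlen, hidden_size]
        seqlen, k = scatter_index.shape
        grad_x = np.zeros((seqlen * k, grad_y.shape[1]), dtype=grad_y.dtype)
        for i in range(seqlen):
            for ki in range(k):
                if abs(combine_weights[i, ki]) <= epsilon:  # padded slot: no gradient
                    continue
                grad_x[scatter_index[i, ki]] = grad_y[i]    # straight copy, no weight scaling
        return grad_x
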
paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_kernel.cu

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename MTP, int k>
__global__ void combine_no_weight_kernel(const T* __restrict__ x,
                                         const T* __restrict__ combine_weights,
                                         const int* __restrict__ scatter_index,
                                         T* __restrict__ y,
                                         const int64_t hidden_size,
                                         const int64_t seqlen,
                                         const float epsilon) {
  // Shared-memory layout: k int64 indices first, then k weights, so the two
  // arrays stay aligned and do not overlap.
  extern __shared__ char shared_mem[];
  int64_t* shared_indices = reinterpret_cast<int64_t*>(shared_mem);
  MTP* shared_weights = reinterpret_cast<MTP*>(shared_mem + k * sizeof(int64_t));

  int64_t seq_i = blockIdx.x;
  for (int ki = threadIdx.x; ki < k; ki += blockDim.x) {
    shared_weights[ki] = static_cast<MTP>(combine_weights[seq_i * k + ki]);
    shared_indices[ki] = scatter_index[seq_i * k + ki];
  }
  __syncthreads();
  for (int h_i = threadIdx.x; h_i < hidden_size; h_i += blockDim.x) {
    MTP sum = static_cast<MTP>(0);
#pragma unroll
    for (int ki = 0; ki < k; ++ki) {
      if (fabsf(shared_weights[ki]) <= epsilon) {  // skip padded slots
        continue;
      }
      int64_t scatter_idx = shared_indices[ki];
      T x_val = x[scatter_idx * hidden_size + h_i];
      sum += static_cast<MTP>(x_val);
    }
    y[seq_i * hidden_size + h_i] = static_cast<T>(sum);
  }
}

template <typename T>
void moe_combine_no_weight_fwd(const T* x,
                               const T* combine_weights,
                               const int* scatter_index,
                               T* y,
                               const int64_t k,
                               const int64_t seqlen,
                               const int64_t hidden_size,
                               const float epsilon,
                               cudaStream_t stream) {
  int threads_per_block = 1024;
  dim3 blockDim(threads_per_block);
  dim3 gridDim(seqlen);
  // Room for k int64 indices plus k float (MTP) weights per block.
  size_t sharedMemSize = k * (sizeof(int64_t) + sizeof(float));

#define CALL_KERNEL(K)                                                  \
  case K:                                                               \
    combine_no_weight_kernel<T, float, K>                               \
        <<<gridDim, blockDim, sharedMemSize, stream>>>(x,               \
                                                       combine_weights, \
                                                       scatter_index,   \
                                                       y,               \
                                                       hidden_size,     \
                                                       seqlen,          \
                                                       epsilon);        \
    break;

  switch (k) {
    CALL_KERNEL(1);
    CALL_KERNEL(2);
    CALL_KERNEL(3);
    CALL_KERNEL(4);
    CALL_KERNEL(5);
    CALL_KERNEL(6);
    CALL_KERNEL(7);
    CALL_KERNEL(8);
    CALL_KERNEL(9);
    CALL_KERNEL(10);
    CALL_KERNEL(11);
    CALL_KERNEL(12);
    CALL_KERNEL(13);
    CALL_KERNEL(14);
    CALL_KERNEL(15);
    CALL_KERNEL(16);
    default:
      PADDLE_THROW(phi::errors::InvalidArgument("Invalid k value."));
      break;
  }
#undef CALL_KERNEL
}

template <typename T, typename Context>
void MoeCombineNoWeightKernel(const Context& dev_ctx,
                              const DenseTensor& x,
                              const DenseTensor& combine_weights,
                              const DenseTensor& scatter_index,
                              const float epsilon,
                              DenseTensor* y) {
  const auto x_shape = x.dims();
  const int64_t hidden_size = x_shape[1];

  const auto scatter_index_shape = scatter_index.dims();
  const int64_t seqlen = scatter_index_shape[0];
  const int64_t k = scatter_index_shape[1];

  dev_ctx.template Alloc<T>(y);

  moe_combine_no_weight_fwd<T>(x.data<T>(),
                               combine_weights.data<T>(),
                               scatter_index.data<int>(),
                               y->data<T>(),
                               k,
                               seqlen,
                               hidden_size,
                               epsilon,
                               dev_ctx.stream());
}

}  // namespace phi

PD_REGISTER_KERNEL(moe_combine_no_weight,
                   GPU,
                   ALL_LAYOUT,
                   phi::MoeCombineNoWeightKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
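
Likewise as a reading aid (not part of the commit), the forward computation in NumPy terms (the helper name is made up here): for each token i, y[i] is the unweighted sum of the rows x[scatter_index[i, ki]] over the top-k slots whose combine weight exceeds epsilon in magnitude.

    import numpy as np

    def moe_combine_no_weight_ref(x, combine_weights, scatter_index, epsilon=1e-15):
        # x: [seqlen * k, hidden_size]; combine_weights, scatter_index: [seqlen, k]
        seqlen, k = scatter_index.shape
        y = np.zeros((seqlen, x.shape[1]), dtype=x.dtype)
        for i in range(seqlen):
            for ki in range(k):
                if abs(combine_weights[i, ki]) <= epsilon:  # skip padded slots
                    continue
                y[i] += x[scatter_index[i, ki]]             # sum without weight scaling
        return y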

paddle/phi/ops/yaml/backward.yaml

Lines changed: 11 additions & 0 deletions
@@ -2320,6 +2320,17 @@
   kernel :
     func : moe_combine_grad
 
+- backward_op : moe_combine_no_weight_grad
+  forward : moe_combine_no_weight (Tensor x, Tensor combine_weight, Tensor scatter_index, float epsilon = 1.0e-15) -> Tensor(y)
+  args : (Tensor x, Tensor combine_weight, Tensor scatter_index, Tensor y_grad, float epsilon = 1.0e-15)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : moe_combine_no_weight_grad
+  no_need_buffer : x
+
 - backward_op : moe_gate_dispatch_grad
   forward : moe_gate_dispatch (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)
   args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, bool use_pad)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 10 additions & 0 deletions
@@ -3678,6 +3678,16 @@
     data_type : x
   backward : moe_combine_grad
 
+- op : moe_combine_no_weight
+  args : (Tensor x, Tensor combine_weight, Tensor scatter_index, float epsilon = 1.0e-15)
+  output : Tensor(y)
+  infer_meta :
+    func : MoeCombineNoWeightInferMeta
+  kernel :
+    func : moe_combine_no_weight
+    data_type : x
+  backward : moe_combine_no_weight_grad
+
 - op : moe_gate_dispatch
   args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad)
   output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)

python/paddle/incubate/nn/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@
 from .int_bincount import int_bincount
 from .masked_multihead_attention import masked_multihead_attention
 from .moe_combine import moe_combine
+from .moe_combine_no_weight import moe_combine_no_weight
 from .moe_gate_dispatch import moe_gate_dispatch
 from .moe_gate_dispatch_partial_nosoftmaxtopk import (
     moe_gate_dispatch_partial_nosoftmaxtopk,
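
The Python wrapper module itself (python/paddle/incubate/nn/functional/moe_combine_no_weight.py) is not shown in this excerpt, so the call below is only a sketch that assumes the wrapper mirrors the op signature declared in ops.yaml (x, combine_weight, scatter_index, epsilon):

    import paddle
    from paddle.incubate.nn.functional import moe_combine_no_weight

    seqlen, k, hidden_size = 4, 2, 8                                        # illustrative sizes
    x = paddle.randn([seqlen * k, hidden_size])
    combine_weight = paddle.rand([seqlen, k])
    scatter_index = paddle.arange(seqlen * k, dtype='int32').reshape([seqlen, k])
    # Assumed signature, mirroring ops.yaml; expected output shape: [seqlen, hidden_size].
    y = moe_combine_no_weight(x, combine_weight, scatter_index, epsilon=1e-15)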
