
Commit 53919d3

Add moe_combine_no_weight op (#73531)
* Add moe_combine_no_weight op
* Remove seqlen and k from parameters
* Set no_need_buffer for x
1 parent 51d2cf4 commit 53919d3

File tree: 11 files changed, +474 -0 lines changed


paddle/phi/infermeta/ternary.cc

Lines changed: 37 additions & 0 deletions
@@ -1630,6 +1630,43 @@ void MoeCombineInferMeta(const MetaTensor& x,
   y->set_dtype(x.dtype());
 }
 
+void MoeCombineNoWeightInferMeta(const MetaTensor& x,
+                                 const MetaTensor& scatter_index,
+                                 MetaTensor* y) {
+  auto x_dim = x.dims();
+  auto scatter_index_dim = scatter_index.dims();
+  PADDLE_ENFORCE_EQ(x_dim.size(),
+                    2,
+                    common::errors::InvalidArgument(
+                        "The dimensions of Input(x) must be 2, but "
+                        "received dimensions of Input(x) is [%d]",
+                        x_dim.size()));
+  PADDLE_ENFORCE_EQ(scatter_index_dim.size(),
+                    2,
+                    common::errors::InvalidArgument(
+                        "The dimensions of Input(scatter_index) must be 2, "
+                        "but received dimensions of Input(scatter_index) is "
+                        "[%d]",
+                        scatter_index_dim.size()));
+  PADDLE_ENFORCE_EQ(scatter_index.dtype(),
+                    phi::DataType::INT32,
+                    common::errors::InvalidArgument(
+                        "The input scatter_index type should be int32, "
+                        "but received scatter_index type = %s",
+                        scatter_index.dtype()));
+  int64_t seqlen = scatter_index_dim[0];
+  int64_t k = scatter_index_dim[1];
+  int64_t hidden_size = x_dim[1];
+  PADDLE_ENFORCE_EQ(x_dim[0],
+                    seqlen * k,
+                    common::errors::InvalidArgument(
+                        "The first dimension of Input(x) [%d] must be equal "
+                        "to the total size of Input(scatter_index) [%d].",
+                        x_dim[0],
+                        seqlen * k));
+  y->set_dims(phi::make_ddim({seqlen, hidden_size}));
+  y->set_dtype(x.dtype());
+}
+
 void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(
     const MetaTensor& x,
     const MetaTensor& combine_weights,
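
For reference, the shape contract enforced above can be sketched in plain Python (a minimal sketch; the function name and asserts are illustrative stand-ins for the PADDLE_ENFORCE_EQ checks, not Paddle API):

import numpy as np

def infer_moe_combine_no_weight_shape(x_shape, scatter_index_shape, scatter_index_dtype):
    # x: [seqlen * k, hidden_size]; scatter_index: [seqlen, k], int32.
    assert len(x_shape) == 2, "x must be 2-D"
    assert len(scatter_index_shape) == 2, "scatter_index must be 2-D"
    assert scatter_index_dtype == np.int32, "scatter_index must be int32"
    seqlen, k = scatter_index_shape
    hidden_size = x_shape[1]
    assert x_shape[0] == seqlen * k, "rows of x must equal seqlen * k"
    return (seqlen, hidden_size)  # output dtype follows x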

paddle/phi/infermeta/ternary.h

Lines changed: 4 additions & 0 deletions
@@ -274,6 +274,10 @@ void MoeCombineInferMeta(const MetaTensor& x,
                         const MetaTensor& scatter_index,
                         MetaTensor* y);
 
+void MoeCombineNoWeightInferMeta(const MetaTensor& x,
+                                 const MetaTensor& scatter_index,
+                                 MetaTensor* y);
+
 void MoeGateDispatchPartialNoSoftmaxTopKInferMeta(
     const MetaTensor& x,
     const MetaTensor& combine_weights,

paddle/phi/kernels/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ if(((WITH_GPU) AND (CUDA_VERSION VERSION_LESS 12.0))
     "legacy/gpu/expand_modality_expert_id_kernel.cu"
     "legacy/gpu/moe_combine_kernel.cu"
     "legacy/gpu/moe_combine_grad_kernel.cu"
+    "legacy/gpu/moe_combine_no_weight_kernel.cu"
+    "legacy/gpu/moe_combine_no_weight_grad_kernel.cu"
     "legacy/gpu/cal_aux_loss_kernel.cu"
     "legacy/gpu/cal_aux_loss_grad_kernel.cu"
     "legacy/gpu/ext_build_src_rank_and_local_expert_id_kernel.cu"
paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_grad_kernel.cu

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"

namespace phi {

template <typename T, typename MTP, int VecSize>
__global__ void combine_no_weight_bwd_kernel(const int* scatter_index,
                                             const T* grad_y,
                                             T* grad_x,
                                             const int64_t k,
                                             const int64_t seqlen,
                                             const int64_t hidden_size) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  LoadT grad_y_vec;
  int i = blockIdx.x;   // Sequence (token) index
  int ki = blockIdx.y;  // Top-k slot index

  if (i < seqlen && ki < k) {
    int idx = scatter_index[i * k + ki];  // Row of x this slot was taken from

    // Loop over the hidden dimension in vectorized strides of the block
    for (int h_i = threadIdx.x * VecSize; h_i < hidden_size;
         h_i += blockDim.x * VecSize) {
      phi::Load<T, VecSize>(&(grad_y[i * hidden_size + h_i]), &grad_y_vec);
      phi::Store<T, VecSize>(grad_y_vec, &grad_x[idx * hidden_size + h_i]);
    }
  }
}

template <typename T>
void moe_combine_no_weight_bwd(const int* scatter_index,
                               const T* grad_y,
                               T* grad_x,
                               const int64_t k,
                               const int64_t seqlen,
                               const int64_t hidden_size,
                               cudaStream_t stream) {
  int block_size = 512;
  int grid_size_i = seqlen;
  int grid_size_k = k;
  dim3 blockDim(block_size);
  dim3 gridDim(grid_size_i, grid_size_k);

  // Use the widest aligned vector load that divides hidden_size evenly.
  constexpr int max_pack_size = 16 / sizeof(T);
  if (hidden_size % max_pack_size == 0) {
    combine_no_weight_bwd_kernel<T, float, max_pack_size>
        <<<gridDim, blockDim, 0, stream>>>(
            scatter_index, grad_y, grad_x, k, seqlen, hidden_size);
  } else {
    combine_no_weight_bwd_kernel<T, float, 1><<<gridDim, blockDim, 0, stream>>>(
        scatter_index, grad_y, grad_x, k, seqlen, hidden_size);
  }
}

template <typename T, typename Context>
void MoeCombineNoWeightGradKernel(const Context& dev_ctx,
                                  const DenseTensor& x,
                                  const DenseTensor& scatter_index,
                                  const DenseTensor& grad_y,
                                  DenseTensor* grad_x) {
  const auto x_shape = x.dims();
  const int64_t hidden_size = x_shape[1];

  const auto scatter_index_shape = scatter_index.dims();
  const int64_t seqlen = scatter_index_shape[0];
  const int64_t k = scatter_index_shape[1];

  dev_ctx.template Alloc<T>(grad_x);
  // Zero-initialize grad_x: rows of x never referenced by scatter_index
  // receive no gradient.
  phi::Full<T, Context>(
      dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);

  moe_combine_no_weight_bwd<T>(scatter_index.data<int>(),
                               grad_y.data<T>(),
                               grad_x->data<T>(),
                               k,
                               seqlen,
                               hidden_size,
                               dev_ctx.stream());
}

}  // namespace phi

PD_REGISTER_KERNEL(moe_combine_no_weight_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MoeCombineNoWeightGradKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
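
The backward pass is a pure scatter: row scatter_index[i][ki] of grad_x receives grad_y[i] unchanged (there are no combine weights), and rows never referenced by scatter_index keep the zeros written by phi::Full. A minimal numpy reference (illustrative only, not Paddle code):

import numpy as np

def moe_combine_no_weight_grad_ref(scatter_index, grad_y):
    # scatter_index: [seqlen, k] int32; grad_y: [seqlen, hidden_size].
    seqlen, k = scatter_index.shape
    grad_x = np.zeros((seqlen * k, grad_y.shape[1]), dtype=grad_y.dtype)
    for i in range(seqlen):
        for ki in range(k):
            # Each selected row of x gets the full output gradient of token i.
            grad_x[scatter_index[i, ki]] = grad_y[i]
    return grad_x

Note that, like the CUDA kernel (phi::Store rather than an atomic add), this assigns instead of accumulating, so it assumes each row of x is referenced by at most one (i, ki) pair.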
paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_kernel.cu

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename MTP, int k>
__global__ void combine_no_weight_kernel(const T* __restrict__ x,
                                         const int* __restrict__ scatter_index,
                                         T* __restrict__ y,
                                         const int64_t hidden_size,
                                         const int64_t seqlen) {
  extern __shared__ char shared_mem[];
  int64_t* shared_indices = reinterpret_cast<int64_t*>(shared_mem);

  // One block per token: stage that token's k scatter indices in shared
  // memory first.
  int64_t seq_i = blockIdx.x;
  for (int ki = threadIdx.x; ki < k; ki += blockDim.x) {
    shared_indices[ki] = scatter_index[seq_i * k + ki];
  }
  __syncthreads();
  // Each thread sums the k selected rows over its slice of the hidden
  // dimension, accumulating in MTP (float) for low-precision T.
  for (int h_i = threadIdx.x; h_i < hidden_size; h_i += blockDim.x) {
    MTP sum = static_cast<MTP>(0);
#pragma unroll
    for (int ki = 0; ki < k; ++ki) {
      int64_t scatter_idx = shared_indices[ki];
      T x_val = x[scatter_idx * hidden_size + h_i];
      sum += static_cast<MTP>(x_val);
    }
    y[seq_i * hidden_size + h_i] = static_cast<T>(sum);
  }
}

template <typename T>
void moe_combine_no_weight_fwd(const T* x,
                               const int* scatter_index,
                               T* y,
                               const int64_t k,
                               const int64_t seqlen,
                               const int64_t hidden_size,
                               cudaStream_t stream) {
  int threads_per_block = 1024;
  dim3 blockDim(threads_per_block);
  dim3 gridDim(seqlen);
  size_t sharedMemSize = k * sizeof(int64_t);

// k is dispatched to a compile-time template argument so the inner loop
// can be fully unrolled; the launch runs on the caller's stream.
#define CALL_KERNEL(K)                                  \
  case K:                                               \
    combine_no_weight_kernel<T, float, K>               \
        <<<gridDim, blockDim, sharedMemSize, stream>>>( \
            x, scatter_index, y, hidden_size, seqlen);  \
    break;

  switch (k) {
    CALL_KERNEL(1);
    CALL_KERNEL(2);
    CALL_KERNEL(3);
    CALL_KERNEL(4);
    CALL_KERNEL(5);
    CALL_KERNEL(6);
    CALL_KERNEL(7);
    CALL_KERNEL(8);
    CALL_KERNEL(9);
    CALL_KERNEL(10);
    CALL_KERNEL(11);
    CALL_KERNEL(12);
    CALL_KERNEL(13);
    CALL_KERNEL(14);
    CALL_KERNEL(15);
    CALL_KERNEL(16);
    default:
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Invalid k value [%d]; only 1 <= k <= 16 is supported.", k));
      break;
  }
#undef CALL_KERNEL
}

template <typename T, typename Context>
void MoeCombineNoWeightKernel(const Context& dev_ctx,
                              const DenseTensor& x,
                              const DenseTensor& scatter_index,
                              DenseTensor* y) {
  const auto x_shape = x.dims();
  const int64_t hidden_size = x_shape[1];

  const auto scatter_index_shape = scatter_index.dims();
  const int64_t seqlen = scatter_index_shape[0];
  const int64_t k = scatter_index_shape[1];

  dev_ctx.template Alloc<T>(y);

  moe_combine_no_weight_fwd<T>(x.data<T>(),
                               scatter_index.data<int>(),
                               y->data<T>(),
                               k,
                               seqlen,
                               hidden_size,
                               dev_ctx.stream());
}

}  // namespace phi

PD_REGISTER_KERNEL(moe_combine_no_weight,
                   GPU,
                   ALL_LAYOUT,
                   phi::MoeCombineNoWeightKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
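
Semantically the forward op is an unweighted gather-and-sum: token i's output is the sum of the k rows of x that scatter_index[i] selects. A minimal numpy reference (illustrative, not the Paddle implementation):

import numpy as np

def moe_combine_no_weight_ref(x, scatter_index):
    # x: [seqlen * k, hidden_size]; scatter_index: [seqlen, k] int32.
    # Fancy indexing gathers to [seqlen, k, hidden_size]; summing over
    # axis 1 combines the k expert outputs per token with equal weight.
    return x[scatter_index].sum(axis=1)  # -> [seqlen, hidden_size]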

paddle/phi/ops/yaml/backward.yaml

Lines changed: 11 additions & 0 deletions
@@ -2320,6 +2320,17 @@
   kernel :
     func : moe_combine_grad
 
+- backward_op : moe_combine_no_weight_grad
+  forward : moe_combine_no_weight (Tensor x, Tensor scatter_index) -> Tensor(y)
+  args : (Tensor x, Tensor scatter_index, Tensor y_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : moe_combine_no_weight_grad
+  no_need_buffer : x
+
 - backward_op : moe_gate_dispatch_grad
   forward : moe_gate_dispatch (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad) -> Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)
   args : (Tensor combine_weights, Tensor scatter_index, Tensor expert_id, Tensor y_grad, Tensor combine_weights_grad, int64_t k, int64_t capacity, bool use_pad)
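
The no_need_buffer : x entry records that the backward kernel only reads x's metadata (its shape, to size x_grad), so x's storage can be freed after the forward pass. A hedged dygraph sketch exercising the registered backward (GPU build assumed, since the kernels above are GPU-only; shapes are illustrative):

import paddle
import paddle.incubate.nn.functional as F

seqlen, k, hidden = 4, 2, 8
x = paddle.randn([seqlen * k, hidden])
x.stop_gradient = False
# A permutation of [0, seqlen * k) so each row of x is used exactly once.
scatter_index = paddle.randperm(seqlen * k).reshape([seqlen, k]).astype('int32')

y = F.moe_combine_no_weight(x, scatter_index)  # [seqlen, hidden]
(x_grad,) = paddle.grad([y], [x], grad_outputs=[paddle.ones_like(y)])
print(x_grad.shape)  # [seqlen * k, hidden]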

paddle/phi/ops/yaml/ops.yaml

Lines changed: 10 additions & 0 deletions
@@ -3678,6 +3678,16 @@
     data_type : x
   backward : moe_combine_grad
 
+- op : moe_combine_no_weight
+  args : (Tensor x, Tensor scatter_index)
+  output : Tensor(y)
+  infer_meta :
+    func : MoeCombineNoWeightInferMeta
+  kernel :
+    func : moe_combine_no_weight
+    data_type : x
+  backward : moe_combine_no_weight_grad
+
 - op : moe_gate_dispatch
   args : (Tensor x, Tensor gate_logits, Tensor corr_bias, int64_t k, int64_t capacity, bool use_pad)
   output : Tensor(y), Tensor(combine_weights), Tensor(scatter_index), Tensor(expert_offset), Tensor(expert_id)

python/paddle/incubate/nn/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@
 from .int_bincount import int_bincount
 from .masked_multihead_attention import masked_multihead_attention
 from .moe_combine import moe_combine
+from .moe_combine_no_weight import moe_combine_no_weight
 from .moe_gate_dispatch import moe_gate_dispatch
 from .moe_gate_dispatch_partial_nosoftmaxtopk import (
     moe_gate_dispatch_partial_nosoftmaxtopk,
python/paddle/incubate/nn/functional/moe_combine_no_weight.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from paddle import _C_ops
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.layer_helper import LayerHelper

if TYPE_CHECKING:
    from paddle import Tensor


def moe_combine_no_weight(
    x: Tensor,
    scatter_index: Tensor,
    name: str | None = None,
) -> Tensor:
    """
    Combine expert outputs without combine weights: for each token, sum the
    k rows of x selected by scatter_index.

    Args:
        x: Input tensor of shape [num_tokens, hidden_size], where
            num_tokens == seq_len * k.
        scatter_index: Scatter indices of shape [seq_len, k], dtype=int32.
        name: Name of the operation (optional, default is None).

    Returns:
        Combined output tensor of shape [seq_len, hidden_size].
    """
    if in_dynamic_or_pir_mode():
        return _C_ops.moe_combine_no_weight(x, scatter_index)
    helper = LayerHelper('moe_combine_no_weight', **locals())
    y = helper.create_variable_for_type_inference(dtype=x.dtype)
    inputs = {
        'x': x,
        'scatter_index': scatter_index,
    }
    helper.append_op(
        type='moe_combine_no_weight', inputs=inputs, outputs={'y': y}
    )
    return y
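
A minimal usage sketch (GPU build assumed; values are illustrative):

import paddle
import paddle.incubate.nn.functional as F

seqlen, k, hidden = 3, 2, 4
x = paddle.arange(seqlen * k * hidden, dtype='float32').reshape([seqlen * k, hidden])
# Token i combines rows scatter_index[i] of x.
scatter_index = paddle.to_tensor([[0, 3], [1, 4], [2, 5]], dtype='int32')
y = F.moe_combine_no_weight(x, scatter_index)  # shape [3, 4]
# y[0] == x[0] + x[3], y[1] == x[1] + x[4], y[2] == x[2] + x[5]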

test/legacy_test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -513,6 +513,7 @@ if(NOT WITH_GPU
     test_incubate_fused_rmsnorm_ext
     test_incubate_int_bincount
     test_incubate_moe_combine
+    test_incubate_moe_combine_no_weight
     test_incubate_moe_gate_dispatch_partial_nosoftmaxtopk
     test_incubate_moe_gate_dispatch_w_permute_bwd
     test_incubate_moe_gate_dispatch_w_permute
