Skip to content

Commit 91767eb

Browse files
authored
cherry-pick fleety's customized moe_permute optimization (#74979)
* cherry-pick fleety * fix miscs * recover fp16 * fix miscs
1 parent c791a51 commit 91767eb

File tree

7 files changed

+122
-43
lines changed

7 files changed

+122
-43
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,7 +6131,8 @@ void MoePermuteInferMeta(const MetaTensor& X,
61316131
const MetaTensor& expert_prob_topk,
61326132
const int num_experts,
61336133
const std::vector<int>& tokens_per_expert,
6134-
const int padding_multiplex,
6134+
const int padding_alignment,
6135+
const bool do_gather,
61356136
MetaTensor* X_unzipped,
61366137
MetaTensor* zipped_expertwise_rowmap,
61376138
MetaTensor* token_prob_unzipped,
@@ -6154,7 +6155,7 @@ void MoePermuteInferMeta(const MetaTensor& X,
61546155
true,
61556156
common::errors::InvalidArgument(
61566157
"Input expert_prob_topk's dtype should be FLOAT32"));
6157-
if (XScale) {
6158+
if (XScale && do_gather) {
61586159
PADDLE_ENFORCE_EQ(XScale.dtype(),
61596160
phi::DataType::FLOAT32,
61606161
common::errors::InvalidArgument(
@@ -6168,8 +6169,16 @@ void MoePermuteInferMeta(const MetaTensor& X,
61686169
}
61696170
const int rows = X.dims()[0];
61706171
const int cols = X.dims()[1];
6171-
X_unzipped->set_dims({-1, cols});
6172-
X_unzipped->set_dtype(X.dtype());
6172+
6173+
if (do_gather) {
6174+
X_unzipped->set_dims({-1, cols});
6175+
X_unzipped->set_dtype(X.dtype());
6176+
} else {
6177+
// Meta only, not
6178+
X_unzipped->set_dims({0, cols});
6179+
X_unzipped->set_dtype(X.dtype());
6180+
}
6181+
61736182
zipped_expertwise_rowmap->set_dims({rows, num_experts});
61746183
zipped_expertwise_rowmap->set_dtype(phi::DataType::INT32);
61756184
token_prob_unzipped->set_dims({-1});
@@ -6356,7 +6365,8 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
63566365
num_head % k_num_head,
63576366
0,
63586367
errors::InvalidArgument(
6359-
"The num_head of query must be divisible by the num_head of key, but "
6368+
"The num_head of query must be divisible by the num_head of key, "
6369+
"but "
63606370
"received num_head of query is %d, and the num_head of key is %d",
63616371
num_head,
63626372
k_num_head));
@@ -6798,6 +6808,5 @@ void MoeGateDispatchAutoInferMeta(const MetaTensor& x,
67986808
expert_id->set_dims(common::make_ddim({num_rows, k}));
67996809
expert_id->set_dtype(phi::DataType::INT32);
68006810
}
6801-
68026811
} // namespace phi
68036812
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);

paddle/phi/infermeta/multiary.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,8 @@ void MoePermuteInferMeta(const MetaTensor& X,
560560
const MetaTensor& expert_prob_topk,
561561
const int num_experts,
562562
const std::vector<int>& tokens_per_expert,
563-
const int padding_multiplex,
563+
const int padding_alignment,
564+
const bool do_gather,
564565
MetaTensor* X_unzipped,
565566
MetaTensor* zipped_expertwise_rowmap,
566567
MetaTensor* token_prob_unzipped,
@@ -858,6 +859,28 @@ void MomentumInferMeta(const MetaTensor& param,
858859
MetaTensor* param_out,
859860
MetaTensor* velocity_out,
860861
MetaTensor* master_param_out);
862+
void MoePermuteInferMeta(const MetaTensor& X,
863+
const MetaTensor& XScale,
864+
const MetaTensor& expert_routemap_topk,
865+
const MetaTensor& expert_prob_topk,
866+
const int num_experts,
867+
const std::vector<int>& tokens_per_expert,
868+
const int padding_alignment,
869+
const bool do_gather,
870+
MetaTensor* X_unzipped,
871+
MetaTensor* zipped_expertwise_rowmap,
872+
MetaTensor* token_prob_unzipped,
873+
MetaTensor* XScale_unzipped);
874+
875+
void MoeUnpermuteInferMeta(const MetaTensor& unzipped_tokens,
876+
const MetaTensor& zipped_expertwise_rowmap,
877+
const MetaTensor& expert_routemap_topk,
878+
const MetaTensor& unzipped_token_probs,
879+
const int total_zipped_tokens_num,
880+
const int num_experts,
881+
const bool MP,
882+
MetaTensor* zipped_tokens,
883+
MetaTensor* zipped_probs_topk);
861884

862885
void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
863886
MetaTensor* out);

paddle/phi/kernels/gpu/moe_permute_kernel.cu

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ struct expert_infos {
4343
}
4444
};
4545

46-
template <typename X_T, typename routemap_T, typename probs_T, bool has_scale>
46+
template <typename X_T,
47+
typename routemap_T,
48+
typename probs_T,
49+
bool has_scale,
50+
bool do_gather>
4751
__global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
4852
const X_T *__restrict__ X,
4953
const routemap_T *__restrict__ routemap_topk,
@@ -130,17 +134,19 @@ __global__ __launch_bounds__(512) void tokens_unzip_stable_kernel(
130134
if (proposed_row_idx == -1) continue; // no memcpy
131135
if (threadIdx.x == 0)
132136
probs_unzipped[proposed_row_idx] = this_expert_token_info.expert_probs;
133-
// vec copy
134-
if constexpr (has_scale) {
137+
if constexpr (do_gather) {
138+
// vec copy
139+
if constexpr (has_scale) {
140+
vectorized_memcpy(&XScale[(int64_t)row * (int64_t)scale_length],
141+
&XScale_unzipped[(int64_t)proposed_row_idx *
142+
(int64_t)scale_length],
143+
scale_length);
144+
}
135145
vectorized_memcpy(
136-
&XScale[(int64_t)row * (int64_t)scale_length],
137-
&XScale_unzipped[(int64_t)proposed_row_idx * (int64_t)scale_length],
138-
scale_length);
146+
&X[(int64_t)row * (int64_t)token_length],
147+
&X_unzipped[(int64_t)proposed_row_idx * (int64_t)token_length],
148+
token_length);
139149
}
140-
vectorized_memcpy(
141-
&X[(int64_t)row * (int64_t)token_length],
142-
&X_unzipped[(int64_t)proposed_row_idx * (int64_t)token_length],
143-
token_length);
144150
}
145151
}
146152
}
@@ -160,7 +166,8 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
160166
const int token_length,
161167
const int topk, // deprecated
162168
const int num_experts,
163-
const int scale_length) {
169+
const int scale_length,
170+
const bool do_gather) {
164171
dim3 grid, block;
165172
grid.x =
166173
(total_zipped_tokens_num + CUMSUM_BLOCK_SIZE - 1) / CUMSUM_BLOCK_SIZE;
@@ -169,33 +176,41 @@ void dispatch_tokens_unzip_stable(const Context &dev_ctx,
169176
#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
170177
#define GET_DATA(tensor, type) tensor.data<type>()
171178
#define GET_PTR_DATA(tensor, type) tensor->data<type>()
172-
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
173-
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, INT_T, PROB_T, HAS_SCALE>; \
174-
kernel<<<grid, block, 0, dev_ctx.stream()>>>( \
175-
GET_DATA(X, TOKEN_T), \
176-
GET_DATA(expert_routemap_topk, INT_T), \
177-
GET_DATA(expert_prob_topk, PROB_T), \
178-
XScale ? XScale.get_ptr()->data<float>() : nullptr, \
179-
GET_DATA(expert_offsets, int), \
180-
GET_PTR_DATA(X_unzipped, TOKEN_T), \
181-
GET_PTR_DATA(zipped_expertwise_rowmap, INT_T), \
182-
GET_PTR_DATA(token_prob_unzipped, PROB_T), \
183-
XScale_unzipped->data<float>(), \
184-
global_expertwise_block_cumsum->data<int>(), \
185-
total_zipped_tokens_num, \
186-
token_length, \
187-
scale_length, \
188-
num_experts, \
179+
#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER) \
180+
auto kernel = tokens_unzip_stable_kernel<TOKEN_T, \
181+
INT_T, \
182+
PROB_T, \
183+
HAS_SCALE, \
184+
DO_GATHER>; \
185+
kernel<<<grid, block, 0, dev_ctx.stream()>>>( \
186+
GET_DATA(X, TOKEN_T), \
187+
GET_DATA(expert_routemap_topk, INT_T), \
188+
GET_DATA(expert_prob_topk, PROB_T), \
189+
XScale ? XScale.get_ptr()->data<float>() : nullptr, \
190+
GET_DATA(expert_offsets, int), \
191+
GET_PTR_DATA(X_unzipped, TOKEN_T), \
192+
GET_PTR_DATA(zipped_expertwise_rowmap, INT_T), \
193+
GET_PTR_DATA(token_prob_unzipped, PROB_T), \
194+
XScale_unzipped->data<float>(), \
195+
global_expertwise_block_cumsum->data<int>(), \
196+
total_zipped_tokens_num, \
197+
token_length, \
198+
scale_length, \
199+
num_experts, \
189200
topk);
190201

191-
#define HANDLE_EXPERT_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
192-
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE)
202+
#define HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
203+
if (do_gather) { \
204+
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, true) \
205+
} else { \
206+
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, false) \
207+
}
193208

194209
#define HANDLE_TOKEN_TYPE(PROB_T, INT_T) \
195210
if (DTYPE_CASE(X.dtype(), BFLOAT16)) { \
196-
HANDLE_EXPERT_CASE(phi::bfloat16, PROB_T, INT_T, false) \
211+
HANDLE_GATHER_CASE(phi::bfloat16, PROB_T, INT_T, false) \
197212
} else if (DTYPE_CASE(X.dtype(), FLOAT8_E4M3FN)) { \
198-
HANDLE_EXPERT_CASE(phi::float8_e4m3fn, PROB_T, INT_T, true) \
213+
HANDLE_GATHER_CASE(phi::float8_e4m3fn, PROB_T, INT_T, true) \
199214
}
200215

201216
#define HANDLE_PROB_TYPE(INT_T) \
@@ -226,6 +241,7 @@ void MoePermuteKernel(const Context &dev_ctx,
226241
const int num_experts,
227242
const std::vector<int> &tokens_per_expert,
228243
const int padding_multiplex,
244+
const bool do_gather,
229245
DenseTensor *X_unzipped,
230246
DenseTensor *zipped_expertwise_rowmap,
231247
DenseTensor *token_prob_unzipped,
@@ -341,7 +357,8 @@ void MoePermuteKernel(const Context &dev_ctx,
341357
cols,
342358
topk_calculated,
343359
num_experts,
344-
quanted_cols);
360+
quanted_cols,
361+
do_gather);
345362
}
346363
#undef CUMSUM_BLOCK_SIZE
347364
#undef CUMSUM_INVALID_TAG

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3886,7 +3886,7 @@
38863886
backward : moe_gate_dispatch_permute_grad
38873887

38883888
- op : moe_permute
3889-
args : (Tensor hidden_states, Tensor scale, Tensor expert_routemap_topk, Tensor expert_prob_topk, int num_experts, int[] tokens_per_expert, int padding_alignment)
3889+
args : (Tensor hidden_states, Tensor scale, Tensor expert_routemap_topk, Tensor expert_prob_topk, int num_experts, int[] tokens_per_expert, int padding_alignment, bool do_gather)
38903890
output : Tensor(hidden_states_unzipped), Tensor(zipped_expertwise_rowmap), Tensor(token_prob_unzipped), Tensor(scale_unzipped)
38913891
infer_meta:
38923892
func : MoePermuteInferMeta

python/paddle/nn/functional/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@
242242
'max_unpool1d',
243243
'max_unpool2d',
244244
'max_unpool3d',
245+
'moe_permute',
246+
'moe_unpermute',
245247
'adaptive_avg_pool1d',
246248
'adaptive_avg_pool2d',
247249
'adaptive_avg_pool3d',
@@ -304,6 +306,4 @@
304306
"flash_attention_v3_varlen",
305307
'flash_attn_varlen_qkvpacked',
306308
'group_norm',
307-
'moe_permute',
308-
'moe_unpermute',
309309
]

python/paddle/nn/functional/moe_permute.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def moe_permute(
3131
num_experts: int,
3232
tokens_per_expert: list,
3333
padding_alignment: int,
34+
do_gather: bool = True,
3435
name: str | None = None,
3536
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
3637
r"""
@@ -67,6 +68,7 @@ def moe_permute(
6768
assigned to the corresponding expert.
6869
padding_alignment (int): Tokens alignment requirement for expert buffers (in bytes).
6970
Must be a power of 2. Typical values are 16, 32 or 64 for optimal memory access.
71+
do_gather(bool): Decide whether do actual tokens gather operation or not, default is True.
7072
name (str|None, optional): Name prefix for the operation (optional).
7173
Default: None
7274
@@ -133,6 +135,7 @@ def moe_permute(
133135
num_experts,
134136
tokens_per_expert,
135137
padding_alignment,
138+
do_gather,
136139
)
137140
return (
138141
hidden_states_unzipped,

test/legacy_test/test_moe_permute_unpermute.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ def test_permute_unpermute_consistency(self):
139139
tokens_per_expert=tokens_per_expert,
140140
padding_alignment=128,
141141
)
142+
# do_gather = False
143+
(
144+
_,
145+
zipped_expertwise_rowmap_no_gather,
146+
unzipped_probs_no_gather,
147+
_,
148+
) = moe_permute(
149+
hidden_states,
150+
scale,
151+
expert_routemap_topk,
152+
expert_prob_topk,
153+
num_experts=expert_num,
154+
tokens_per_expert=tokens_per_expert,
155+
padding_alignment=128,
156+
do_gather=False,
157+
)
142158

143159
unpermute_input = (
144160
unzipped_tokens.astype("float32")
@@ -174,6 +190,17 @@ def test_permute_unpermute_consistency(self):
174190
err_msg="moe_permute_unpermute probs do not match",
175191
)
176192

193+
np.testing.assert_equal(
194+
zipped_expertwise_rowmap_no_gather._md5sum(),
195+
zipped_expertwise_rowmap._md5sum(),
196+
err_msg="no_gather's zipped_expertwise_rowmap do not match",
197+
)
198+
np.testing.assert_equal(
199+
unzipped_probs_no_gather._md5sum(),
200+
unzipped_probs._md5sum(),
201+
err_msg="no_gather's unzipped_probs do not match",
202+
)
203+
177204

178205
if __name__ == "__main__":
179206
unittest.main()

0 commit comments

Comments
 (0)