Skip to content

Commit af543b7

Browse files
authored
revise get_moe_scores (#3164)
1 parent e24929e commit af543b7

File tree

6 files changed

+165
-23
lines changed

6 files changed

+165
-23
lines changed

custom_ops/gpu_ops/noaux_tc.cu

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
3333
auto input_type = scores_with_bias.dtype();
3434
auto place = scores_with_bias.place();
3535
auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
36+
auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
37+
auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
3638
auto stream = scores_with_bias.stream();
3739

38-
invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
40+
invokeNoAuxTc<float, int32_t>(reinterpret_cast<float*>(scores.data<float>()),
3941
reinterpret_cast<float*>(group_scores.data<float>()),
42+
reinterpret_cast<float*>(topk_values.data<float>()),
43+
reinterpret_cast<int32_t*>(topk_indices.data<int32_t>()),
4044
reinterpret_cast<float*>(scores_with_bias.data<float>()),
4145
num_tokens,
4246
num_experts,
@@ -46,19 +50,23 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
4650
routed_scaling_factor,
4751
stream);
4852

49-
return {scores};
53+
return {scores, topk_values, topk_indices};
5054
}
5155

5256
std::vector<paddle::DataType> NoauxTcInferDtype(
5357
const paddle::DataType& scores_dtype,
5458
const paddle::DataType& scores_with_bias_dtype) {
55-
return {scores_dtype};
59+
return {scores_dtype, scores_dtype, paddle::DataType::INT32};
5660
}
5761

5862
std::vector<std::vector<int64_t>> NoauxTcInferShape(
5963
const std::vector<int64_t>& scores_shape,
60-
const std::vector<int64_t>& gating_output_shape) {
61-
return {scores_shape};
64+
const std::vector<int64_t>& ,
65+
const int topk) {
66+
auto num_tokens = scores_shape[0];
67+
auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
68+
auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
69+
return {scores_shape, topk_values_shape, topk_indices_shape};
6270
}
6371

6472
PD_BUILD_STATIC_OP(noaux_tc)

custom_ops/gpu_ops/noauxtc_kernel.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,12 @@ __global__ void topk_with_k2_kernel(T* output,
372372
}
373373
}
374374

375-
template <typename T>
375+
template <typename T, typename IdxT>
376376
__global__ void group_idx_and_topk_idx_kernel(
377377
T* scores,
378378
T const* group_scores,
379+
T* topk_values,
380+
IdxT* topk_indices,
379381
T* scores_with_bias,
380382
int64_t const num_tokens,
381383
int64_t const n_group,
@@ -391,6 +393,8 @@ __global__ void group_idx_and_topk_idx_kernel(
391393
scores_with_bias += case_id * num_experts;
392394
scores += case_id * num_experts;
393395
group_scores += case_id * n_group;
396+
topk_values += case_id * topk;
397+
topk_indices += case_id * topk;
394398
int32_t align_num_experts_per_group =
395399
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
396400

@@ -436,6 +440,7 @@ __global__ void group_idx_and_topk_idx_kernel(
436440
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
437441

438442
int count_equalto_topkth_group = 0;
443+
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
439444
if (case_id < num_tokens) {
440445
for (int i_group = 0; i_group < n_group; i_group++) {
441446
if ((group_scores[i_group] > topk_group_value) ||
@@ -490,13 +495,23 @@ __global__ void group_idx_and_topk_idx_kernel(
490495
for (int i = lane_id; i < topk; i += WARP_SIZE) {
491496
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
492497
scores[s_topk_idx[i]] = value;
498+
if (if_proceed_next_topk) {
499+
topk_indices[i] = s_topk_idx[i];
500+
topk_values[i] = static_cast<T>(value);
501+
}
502+
else {
503+
topk_indices[i] = i;
504+
topk_values[i] = static_cast<float>(1.0f / topk);
505+
}
493506
}
494507
}
495508
}
496509

497-
template <typename T>
510+
template <typename T, typename IdxT>
498511
void invokeNoAuxTc(T* scores,
499512
T* group_scores,
513+
T* topk_values,
514+
IdxT* topk_indices,
500515
T* scores_with_bias,
501516
int64_t const num_tokens,
502517
int64_t const num_experts,
@@ -526,6 +541,8 @@ void invokeNoAuxTc(T* scores,
526541
dynamic_smem_in_bytes,
527542
stream>>>(scores,
528543
group_scores,
544+
topk_values,
545+
topk_indices,
529546
scores_with_bias,
530547
num_tokens,
531548
n_group,
@@ -536,9 +553,11 @@ void invokeNoAuxTc(T* scores,
536553
routed_scaling_factor);
537554
}
538555

539-
#define INSTANTIATE_NOAUX_TC(T) \
540-
template void invokeNoAuxTc<T>(T * scores, \
556+
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
557+
template void invokeNoAuxTc<T, IdxT>(T * scores, \
541558
T * group_scores, \
559+
T* topk_values, \
560+
IdxT* topk_indices, \
542561
T * scores_with_bias, \
543562
int64_t const num_tokens, \
544563
int64_t const num_experts, \
@@ -548,4 +567,4 @@ void invokeNoAuxTc(T* scores,
548567
double const routed_scaling_factor, \
549568
cudaStream_t const stream);
550569

551-
INSTANTIATE_NOAUX_TC(float);
570+
INSTANTIATE_NOAUX_TC(float, int32_t);

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,35 @@
3131
from fastdeploy.config import MoEPhase
3232
from fastdeploy.utils import singleton
3333

34+
try:
35+
from fastdeploy.model_executor.ops.gpu import noaux_tc
36+
except:
37+
logger.warning("import noaux_tc Failed!")
38+
39+
40+
def get_moe_scores(
41+
gating_output: paddle.Tensor,
42+
n_group,
43+
topk_group,
44+
top_k,
45+
routed_scaling_factor,
46+
e_score_correction_bias,
47+
) -> paddle.Tensor:
48+
"""
49+
compute moe scores using e_score_correction_bias.
50+
"""
51+
scores = paddle.nn.functional.sigmoid(gating_output)
52+
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
53+
scores, topk_values, topk_idx = noaux_tc(
54+
scores,
55+
scores_with_bias,
56+
n_group,
57+
topk_group,
58+
top_k,
59+
routed_scaling_factor,
60+
)
61+
return scores, topk_values, topk_idx
62+
3463

3564
@singleton
3665
class DeepEPEngine:
@@ -284,13 +313,23 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
284313
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
285314
)
286315
else:
287-
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
288-
gate_out,
289-
layer.gate_correction_bias,
290-
self.top_k,
291-
True, # apply_norm_weight,
292-
False,
293-
)
316+
if layer.topk_method == "noaux_tc":
317+
score, topk_weights, topk_idx = get_moe_scores(
318+
gate_out,
319+
layer.n_group,
320+
layer.topk_group,
321+
layer.top_k,
322+
layer.routed_scaling_factor,
323+
layer.gate_correction_bias,
324+
)
325+
else:
326+
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
327+
gate_out,
328+
layer.gate_correction_bias,
329+
self.top_k,
330+
True, # apply_norm_weight,
331+
False,
332+
)
294333
return topk_idx, topk_weights
295334

296335
@abstractmethod

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ def get_moe_scores(
5353
"""
5454
scores = paddle.nn.functional.sigmoid(gating_output)
5555
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
56-
scores = noaux_tc(
56+
scores, topk_values, topk_idx = noaux_tc(
5757
scores,
5858
scores_with_bias,
5959
n_group,
6060
topk_group,
6161
top_k,
6262
routed_scaling_factor,
6363
)
64-
return scores
64+
return scores, topk_values, topk_idx
6565

6666

6767
class CutlassMoEMethod(MoEMethodBase):
@@ -248,7 +248,7 @@ def apply_tp(
248248
Paddle Cutlass compute Fused MoE.
249249
"""
250250
if layer.topk_method == "noaux_tc":
251-
gate_out = get_moe_scores(
251+
gate_out, _, _ = get_moe_scores(
252252
gate_out,
253253
layer.n_group,
254254
layer.topk_group,

fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def get_moe_scores(
4141
"""
4242
scores = paddle.nn.functional.sigmoid(gating_output)
4343
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
44-
scores = noaux_tc(
44+
scores, topk_values, topk_idx = noaux_tc(
4545
scores,
4646
scores_with_bias,
4747
n_group,
4848
topk_group,
4949
top_k,
5050
routed_scaling_factor,
5151
)
52-
return scores
52+
return scores, topk_values, topk_idx
5353

5454

5555
def gptq_marlin_moe_repack(
@@ -233,7 +233,7 @@ def apply(
233233
topk_method = layer.topk_method
234234

235235
if topk_method == "noaux_tc":
236-
gate_out = get_moe_scores(
236+
gate_out, _, _ = get_moe_scores(
237237
gate_out,
238238
layer.n_group,
239239
layer.topk_group,

test/operators/test_noaux_tc.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import unittest
2+
3+
import paddle
4+
5+
from fastdeploy.model_executor.ops.gpu import noaux_tc
6+
7+
8+
class TestMoeRouting(unittest.TestCase):
9+
def setUp(self):
10+
self.num_tokens = 10
11+
self.num_experts = 64
12+
self.gating_output = paddle.rand([self.num_tokens, self.num_experts])
13+
self.e_score_correction_bias = paddle.rand([self.num_experts])
14+
self.n_group = 8
15+
self.topk_group = 4
16+
self.top_k = 8
17+
self.routed_scaling_factor = 1.5
18+
19+
def node_limit_routing(self, gate_probs):
20+
"""将所有专家分组, 只在topk_group个group内选择专家"""
21+
assert len(gate_probs.shape) == 2
22+
seq_length, n_experts = gate_probs.shape
23+
24+
group_scores = gate_probs.reshape([seq_length, 8, -1]).topk(2, axis=-1)[0].sum(axis=-1)
25+
group_idx = paddle.topk(group_scores, k=4, axis=-1, sorted=True)[1]
26+
group_mask = paddle.zeros_like(group_scores).put_along_axis(
27+
group_idx, paddle.ones([], dtype="float32"), axis=-1
28+
)
29+
score_mask = group_mask.unsqueeze(-1).expand([seq_length, 8, n_experts // 8]).reshape([seq_length, -1])
30+
gate_probs = gate_probs.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
31+
return gate_probs
32+
33+
def ref_moe_routing(self):
34+
scores = paddle.nn.functional.sigmoid(self.gating_output)
35+
prob_for_choice = scores + self.e_score_correction_bias.unsqueeze(0)
36+
prob_for_choice = self.node_limit_routing(prob_for_choice)
37+
top_logits, topk_idx_ref = paddle.topk(prob_for_choice, self.top_k, axis=1)
38+
39+
token_num, top_k = topk_idx_ref.shape
40+
_, num_expert = prob_for_choice.shape
41+
topk_idx_expanded = paddle.unsqueeze(topk_idx_ref, axis=-1)
42+
indices = paddle.concat(
43+
[
44+
paddle.arange(token_num, dtype="int64").unsqueeze(1).tile([1, top_k]).unsqueeze(-1),
45+
topk_idx_expanded,
46+
],
47+
axis=-1,
48+
)
49+
selected_gate_probs = paddle.gather_nd(scores, indices)
50+
51+
selected_gate_probs_sum = paddle.sum(selected_gate_probs, axis=1, keepdim=True)
52+
topk_weights_ref = selected_gate_probs / selected_gate_probs_sum
53+
topk_weights_ref = topk_weights_ref * self.routed_scaling_factor
54+
return topk_weights_ref, topk_idx_ref
55+
56+
def test_moe_select(self):
57+
scores = paddle.nn.functional.sigmoid(self.gating_output)
58+
scores_with_bias = scores + self.e_score_correction_bias.unsqueeze(0)
59+
60+
scores, topk_values, topk_idx = noaux_tc(
61+
scores,
62+
scores_with_bias,
63+
self.n_group,
64+
self.topk_group,
65+
self.top_k,
66+
self.routed_scaling_factor,
67+
)
68+
69+
ref_topk_values, ref_topk_idx = self.ref_moe_routing()
70+
71+
paddle.allclose(topk_values, ref_topk_values)
72+
paddle.allclose(topk_idx.cast(int), ref_topk_idx.cast(int))
73+
74+
75+
if __name__ == "__main__":
76+
unittest.main()

0 commit comments

Comments
 (0)