Skip to content

Commit ff83763

Browse files
authored
add expert num 32 (#10802)
1 parent b9db2c1 commit ff83763

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,5 +123,6 @@ __device__ __forceinline__ void vectorized_memcpy(const T* src,
123123
auto __num_expert = (__num_experts_expr); \
124124
PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 8, __VA_ARGS__); \
125125
PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 16, __VA_ARGS__); \
126+
PD_SWITCH_NUM_EXPERTS_IMPL(__num_expert, 32, __VA_ARGS__); \
126127
PD_THROW("Unsupported expert number %d", int(__num_expert)); \
127128
} while (0)

tests/ops/test_unzip_zip.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
1-
import numpy as np
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
215
import paddle
316
import TokenDispatcherUtils as TDU
417

518

6-
def fabricate_dispatch_result(
7-
seqlen, token_length, topk, num_experts, data_type="bfloat32", broadcast_ratio=0.5
8-
):
19+
def fabricate_dispatch_result(seqlen, token_length, topk, num_experts, data_type="bfloat32", broadcast_ratio=0.5):
920
tokens = paddle.randn([seqlen, token_length], dtype=data_type)
1021

1122
tokens_scale = paddle.empty([0])
@@ -47,9 +58,7 @@ def fabricate_dispatch_result(
4758
valid_experts = valid_indices[valid_mask]
4859

4960
# 使用histogram统计每个专家的token数
50-
expert_counts = paddle.histogram(
51-
valid_experts, bins=num_experts, min=0, max=num_experts - 1
52-
)
61+
expert_counts = paddle.histogram(valid_experts, bins=num_experts, min=0, max=num_experts - 1)
5362
expert_counts = paddle.cast(expert_counts, "int32")
5463
expert_counts = list(expert_counts)
5564
print("expert counts: ", expert_counts)
@@ -78,11 +87,7 @@ def test_unzip_zip():
7887
for expert_num in [4, 8, 16, 32]:
7988
for topk in [4, 8, 12]:
8089
print("###################################")
81-
print(
82-
"testing with {} experts and topk {}, datatype is {}".format(
83-
expert_num, topk, dt
84-
)
85-
)
90+
print("testing with {} experts and topk {}, datatype is {}".format(expert_num, topk, dt))
8691
(
8792
tokens,
8893
tokens_scale,
@@ -112,7 +117,8 @@ def test_unzip_zip():
112117
topk=topk,
113118
num_experts=expert_num,
114119
tokens_per_expert=expert_tokens_count,
115-
padding_multiplex=128
120+
padding_multiplex=128,
121+
fill_output=True,
116122
)
117123
tokens_recovered, probs_recovered = TDU.tokens_zip(
118124
(unzipped_tokens * unzipped_probs.unsqueeze(-1)).astype("bfloat16"),
@@ -122,11 +128,7 @@ def test_unzip_zip():
122128
total_zipped_tokens=SEQLEN,
123129
num_experts=expert_num,
124130
)
125-
print(
126-
"unzip-zip tokens 最大绝对误差:{}, 相对误差:{}".format(
127-
*tensor_max_abs_rel_err(tokens, tokens_recovered)
128-
)
129-
)
131+
print("unzip-zip tokens 最大绝对误差:{}, 相对误差:{}".format(*tensor_max_abs_rel_err(tokens, tokens_recovered)))
130132
print(
131133
"unzip-zip probs 最大绝对误差:{}, 相对误差:{}".format(
132134
*tensor_max_abs_rel_err(dispatched_probs, probs_recovered)

0 commit comments

Comments
 (0)