1
- import numpy as np
1
+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
2
15
import paddle
3
16
import TokenDispatcherUtils as TDU
4
17
5
18
6
- def fabricate_dispatch_result (
7
- seqlen , token_length , topk , num_experts , data_type = "bfloat32" , broadcast_ratio = 0.5
8
- ):
19
+ def fabricate_dispatch_result (seqlen , token_length , topk , num_experts , data_type = "bfloat32" , broadcast_ratio = 0.5 ):
9
20
tokens = paddle .randn ([seqlen , token_length ], dtype = data_type )
10
21
11
22
tokens_scale = paddle .empty ([0 ])
@@ -47,9 +58,7 @@ def fabricate_dispatch_result(
47
58
valid_experts = valid_indices [valid_mask ]
48
59
49
60
# 使用histogram统计每个专家的token数
50
- expert_counts = paddle .histogram (
51
- valid_experts , bins = num_experts , min = 0 , max = num_experts - 1
52
- )
61
+ expert_counts = paddle .histogram (valid_experts , bins = num_experts , min = 0 , max = num_experts - 1 )
53
62
expert_counts = paddle .cast (expert_counts , "int32" )
54
63
expert_counts = list (expert_counts )
55
64
print ("expert counts: " , expert_counts )
@@ -78,11 +87,7 @@ def test_unzip_zip():
78
87
for expert_num in [4 , 8 , 16 , 32 ]:
79
88
for topk in [4 , 8 , 12 ]:
80
89
print ("###################################" )
81
- print (
82
- "testing with {} experts and topk {}, datatype is {}" .format (
83
- expert_num , topk , dt
84
- )
85
- )
90
+ print ("testing with {} experts and topk {}, datatype is {}" .format (expert_num , topk , dt ))
86
91
(
87
92
tokens ,
88
93
tokens_scale ,
@@ -112,7 +117,8 @@ def test_unzip_zip():
112
117
topk = topk ,
113
118
num_experts = expert_num ,
114
119
tokens_per_expert = expert_tokens_count ,
115
- padding_multiplex = 128
120
+ padding_multiplex = 128 ,
121
+ fill_output = True ,
116
122
)
117
123
tokens_recovered , probs_recovered = TDU .tokens_zip (
118
124
(unzipped_tokens * unzipped_probs .unsqueeze (- 1 )).astype ("bfloat16" ),
@@ -122,11 +128,7 @@ def test_unzip_zip():
122
128
total_zipped_tokens = SEQLEN ,
123
129
num_experts = expert_num ,
124
130
)
125
- print (
126
- "unzip-zip tokens 最大绝对误差:{}, 相对误差:{}" .format (
127
- * tensor_max_abs_rel_err (tokens , tokens_recovered )
128
- )
129
- )
131
+ print ("unzip-zip tokens 最大绝对误差:{}, 相对误差:{}" .format (* tensor_max_abs_rel_err (tokens , tokens_recovered )))
130
132
print (
131
133
"unzip-zip probs 最大绝对误差:{}, 相对误差:{}" .format (
132
134
* tensor_max_abs_rel_err (dispatched_probs , probs_recovered )
0 commit comments