from .fp8_utils import dequantize_fp8_to_fp32

- try:
-     import TokenDispatcherUtils as TDU
- except:
-     pass
-
if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"):

    def _clear_to_zero_allocation(self):
@@ -151,17 +146,21 @@ def forward(
        tokens_per_expert,
    ):
        if isinstance(hs_2d_dispatched, tuple):
-             (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_scale,) = TDU.tokens_unzip_stable(
-                 hs_2d_dispatched[0],
-                 hs_2d_dispatched[1],
-                 dispatched_indices,
-                 dispatched_probs,
-                 topk=topk,
-                 num_experts=num_experts,
-                 tokens_per_expert=tokens_per_expert,
-                 padding_multiplex=128,
-                 fill_output=True,
-             )
+             with paddle.amp.auto_cast(False):
+                 (
+                     unzipped_tokens,
+                     zipped_expertwise_rowmap,
+                     unzipped_probs,
+                     unzipped_scale,
+                 ) = paddle.nn.functional.moe_permute(
+                     hs_2d_dispatched[0],
+                     hs_2d_dispatched[1],
+                     dispatched_indices,
+                     dispatched_probs,
+                     num_experts=num_experts,
+                     tokens_per_expert=tokens_per_expert,
+                     padding_alignment=128,
+                 )
        else:
            with paddle.amp.auto_cast(False):
                (
@@ -184,16 +183,17 @@ def forward(

    @paddle.no_grad()
    def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, num_experts):
-         weighted_zipped_tokens, probs_grad_zipped = TDU.tokens_zip(
-             dx,
-             self.zipped_expertwise_rowmap,
-             dispatched_indices,
-             probs_grad,
-             total_zipped_tokens=hidden_states_out_grad[0].shape[0]
-             if isinstance(hidden_states_out_grad, tuple)
-             else hidden_states_out_grad.shape[0],
-             num_experts=num_experts,
-         )
+         with paddle.amp.auto_cast(False):
+             weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute(
+                 dx,
+                 self.zipped_expertwise_rowmap,
+                 dispatched_indices,
+                 probs_grad,
+                 total_zipped_tokens=hidden_states_out_grad[0].shape[0]
+                 if isinstance(hidden_states_out_grad, tuple)
+                 else hidden_states_out_grad.shape[0],
+                 num_experts=num_experts,
+             )
        self.reset_statue()
        return weighted_zipped_tokens, probs_grad_zipped
@@ -207,9 +207,10 @@ def __init__(self, token_dispatcher, name="zip"):
    def forward(
        self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
    ):
-         expert_out_zipped, zipped_probs_topk = TDU.tokens_zip(
-             expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
-         )
+         with paddle.amp.auto_cast(False):
+             expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute(
+                 expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
+             )
        return expert_out_zipped

    @paddle.no_grad()
@@ -223,23 +224,22 @@ def backward(
        tokens_per_expert,
    ):
        if isinstance(grad_output, tuple):
-             (
-                 unzipped_grad,
-                 zipped_expertwise_rowmap_grad,
-                 unzipped_probs_grad,
-                 unzipped_scale_grad,
-             ) = TDU.tokens_unzip_stable(
-                 grad_output[0],
-                 grad_output[1],
-                 dispatched_indices,
-                 dispatched_probs,
-                 top_k,
-                 num_experts,
-                 tokens_per_expert,
-                 padding_multiplex=128,
-                 fill_output=True,
-             )
-             return (unzipped_grad, unzipped_scale_grad)
+             with paddle.amp.auto_cast(False):
+                 (
+                     unzipped_grad,
+                     zipped_expertwise_rowmap_grad,
+                     unzipped_probs_grad,
+                     unzipped_scale_grad,
+                 ) = paddle.nn.functional.moe_permute(
+                     grad_output[0],
+                     grad_output[1],
+                     dispatched_indices,
+                     dispatched_probs,
+                     num_experts,
+                     tokens_per_expert,
+                     padding_alignment=128,
+                 )
+             return (unzipped_grad, unzipped_scale_grad)
        else:
            with paddle.amp.auto_cast(False):
                (
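For orientation only, the sketch below shows how the `moe_permute` / `moe_unpermute` pair introduced in this diff fits together in a dispatcher's unzip → expert → zip flow, with the permutation kernels kept outside AMP via `paddle.amp.auto_cast(False)`. This is a minimal sketch, not code from the PR: it assumes `paddle.nn.functional.moe_permute` and `moe_unpermute` are available with the signatures used above, and `permute_compute_unpermute`, `expert_fn`, `hs_2d`, and `hs_scale` are hypothetical stand-ins for the dispatcher's own tensors and expert computation.

```python
import paddle
import paddle.nn.functional as F


def permute_compute_unpermute(hs_2d, hs_scale, dispatched_indices, dispatched_probs,
                              num_experts, tokens_per_expert, expert_fn):
    """Sketch of the unzip -> expert -> zip flow, mirroring the calls in this diff."""
    # Run the permutation kernels outside AMP so their dtypes are not auto-cast.
    with paddle.amp.auto_cast(False):
        (
            unzipped_tokens,            # tokens regrouped per expert
            zipped_expertwise_rowmap,   # row map needed to zip results back later
            unzipped_probs,             # routing probabilities aligned to the unzipped rows
            unzipped_scale,             # quantization scales aligned to the unzipped rows
        ) = F.moe_permute(
            hs_2d,                      # dispatched 2D activations
            hs_scale,                   # their quantization scales
            dispatched_indices,
            dispatched_probs,
            num_experts=num_experts,
            tokens_per_expert=tokens_per_expert,
            padding_alignment=128,      # pad each expert's rows to a multiple of 128
        )

    # Hypothetical grouped expert computation over the unzipped tokens.
    expert_out = expert_fn(unzipped_tokens, unzipped_scale, unzipped_probs)

    with paddle.amp.auto_cast(False):
        expert_out_zipped, zipped_probs_topk = F.moe_unpermute(
            expert_out,
            zipped_expertwise_rowmap,
            dispatched_indices,
            unzipped_probs,
            total_zipped_tokens=hs_2d.shape[0],
            num_experts=num_experts,
        )
    return expert_out_zipped
```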