Commit 33e01f1

chen2016013 and phlrain authored
replace custom ops with paddle ops (#10857)
Co-authored-by: phlrain <[email protected]>
1 parent 6cd171e commit 33e01f1

File tree

2 files changed: +46 -47 lines

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 0 additions & 1 deletion
@@ -70,7 +70,6 @@
 
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
-DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
 
 
 def parse_args(args):

paddlenlp/transformers/moe_utils.py

Lines changed: 46 additions & 46 deletions
@@ -19,11 +19,6 @@
 
 from .fp8_utils import dequantize_fp8_to_fp32
 
-try:
-    import TokenDispatcherUtils as TDU
-except:
-    pass
-
 if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"):
 
     def _clear_to_zero_allocation(self):
@@ -151,17 +146,21 @@ def forward(
         tokens_per_expert,
     ):
         if isinstance(hs_2d_dispatched, tuple):
-            (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_scale,) = TDU.tokens_unzip_stable(
-                hs_2d_dispatched[0],
-                hs_2d_dispatched[1],
-                dispatched_indices,
-                dispatched_probs,
-                topk=topk,
-                num_experts=num_experts,
-                tokens_per_expert=tokens_per_expert,
-                padding_multiplex=128,
-                fill_output=True,
-            )
+            with paddle.amp.auto_cast(False):
+                (
+                    unzipped_tokens,
+                    zipped_expertwise_rowmap,
+                    unzipped_probs,
+                    unzipped_scale,
+                ) = paddle.nn.functional.moe_permute(
+                    hs_2d_dispatched[0],
+                    hs_2d_dispatched[1],
+                    dispatched_indices,
+                    dispatched_probs,
+                    num_experts=num_experts,
+                    tokens_per_expert=tokens_per_expert,
+                    padding_alignment=128,
+                )
         else:
             with paddle.amp.auto_cast(False):
                 (
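Note that each replacement call is wrapped in paddle.amp.auto_cast(False), matching the non-tuple branch that already used it, presumably so the permute bookkeeping runs in full precision even inside an AMP region. A minimal standalone sketch of that wrapper pattern (the tensors and matmul below are illustrative only, not part of this change):

import paddle

x = paddle.randn([4, 8], dtype="float32")
w = paddle.randn([8, 8], dtype="float32")

# On a device that supports bf16 autocast, ops inside auto_cast(True) run in
# low precision; the inner auto_cast(False) region stays in float32, which is
# the same guard this commit places around moe_permute / moe_unpermute.
with paddle.amp.auto_cast(True, dtype="bfloat16"):
    y_low = paddle.matmul(x, w)
    with paddle.amp.auto_cast(False):
        y_full = paddle.matmul(x, w)

print(y_low.dtype, y_full.dtype)  # typically bfloat16 vs float32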
@@ -184,16 +183,17 @@ def forward(
 
     @paddle.no_grad()
     def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, num_experts):
-        weighted_zipped_tokens, probs_grad_zipped = TDU.tokens_zip(
-            dx,
-            self.zipped_expertwise_rowmap,
-            dispatched_indices,
-            probs_grad,
-            total_zipped_tokens=hidden_states_out_grad[0].shape[0]
-            if isinstance(hidden_states_out_grad, tuple)
-            else hidden_states_out_grad.shape[0],
-            num_experts=num_experts,
-        )
+        with paddle.amp.auto_cast(False):
+            weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute(
+                dx,
+                self.zipped_expertwise_rowmap,
+                dispatched_indices,
+                probs_grad,
+                total_zipped_tokens=hidden_states_out_grad[0].shape[0]
+                if isinstance(hidden_states_out_grad, tuple)
+                else hidden_states_out_grad.shape[0],
+                num_experts=num_experts,
+            )
         self.reset_statue()
         return weighted_zipped_tokens, probs_grad_zipped
 
@@ -207,9 +207,10 @@ def __init__(self, token_dispatcher, name="zip"):
     def forward(
         self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
     ):
-        expert_out_zipped, zipped_probs_topk = TDU.tokens_zip(
-            expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
-        )
+        with paddle.amp.auto_cast(False):
+            expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute(
+                expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
+            )
         return expert_out_zipped
 
     @paddle.no_grad()
@@ -223,23 +224,22 @@ def backward(
         tokens_per_expert,
     ):
         if isinstance(grad_output, tuple):
-            (
-                unzipped_grad,
-                zipped_expertwise_rowmap_grad,
-                unzipped_probs_grad,
-                unzipped_scale_grad,
-            ) = TDU.tokens_unzip_stable(
-                grad_output[0],
-                grad_output[1],
-                dispatched_indices,
-                dispatched_probs,
-                top_k,
-                num_experts,
-                tokens_per_expert,
-                padding_multiplex=128,
-                fill_output=True,
-            )
-            return (unzipped_grad, unzipped_scale_grad)
+            with paddle.amp.auto_cast(False):
+                (
+                    unzipped_grad,
+                    zipped_expertwise_rowmap_grad,
+                    unzipped_probs_grad,
+                    unzipped_scale_grad,
+                ) = paddle.nn.functional.moe_permute(
+                    grad_output[0],
+                    grad_output[1],
+                    dispatched_indices,
+                    dispatched_probs,
+                    num_experts,
+                    tokens_per_expert,
+                    padding_alignment=128,
+                )
+            return (unzipped_grad, unzipped_scale_grad)
         else:
             with paddle.amp.auto_cast(False):
                 (
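Taken together, these moe_utils.py hunks swap the custom TokenDispatcherUtils kernels for built-in ops: paddle.nn.functional.moe_permute unzips dispatched tokens into contiguous per-expert rows, and paddle.nn.functional.moe_unpermute zips expert output back into dispatch order. A hedged sketch of that round trip, reusing the argument order shown in the hunks above; the expert_fn hook and the exact layouts of the dispatcher outputs are assumptions, not something this commit defines:

import paddle
import paddle.nn.functional as F


def permute_run_unpermute(hs, scale, dispatched_indices, dispatched_probs,
                          tokens_per_expert, num_experts, expert_fn):
    # hs / scale / indices / probs / tokens_per_expert are whatever the token
    # dispatcher produced; their layouts are not specified by this commit.
    with paddle.amp.auto_cast(False):
        unzipped_tokens, rowmap, unzipped_probs, unzipped_scale = F.moe_permute(
            hs,
            scale,
            dispatched_indices,
            dispatched_probs,
            num_experts=num_experts,
            tokens_per_expert=tokens_per_expert,
            padding_alignment=128,
        )

    # expert_fn stands in for the grouped expert MLPs (hypothetical hook).
    expert_out = expert_fn(unzipped_tokens, unzipped_scale)

    with paddle.amp.auto_cast(False):
        # Positional order follows the second moe_unpermute call above.
        zipped_out, zipped_probs_topk = F.moe_unpermute(
            expert_out, rowmap, dispatched_indices, unzipped_probs,
            hs.shape[0], num_experts,
        )
    return zipped_out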
