Skip to content

Commit 740f471

Browse files
authored
fix numel bug (#10979)
1 parent 45b9caa commit 740f471

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def compute_expert_w_grad(
281281
统一处理 expert_w 的梯度计算(支持 main_grad 和普通 grad)
282282
"""
283283

284-
if input_t is None or input_t.numel() == 0:
284+
if input_t is None or numpy.prod(input_t.shape) == 0:
285285
return
286286

287287
if hasattr(weight, "main_grad"):

0 commit comments

Comments
 (0)