Commit 2ff16f8

fix 0 size bug (#10906)
1 parent: 48e3e32 · commit: 2ff16f8

File tree: 1 file changed, +17 −3 lines changed


paddlenlp/transformers/fp8_utils.py

Lines changed: 17 additions & 3 deletions
```diff
@@ -917,18 +917,17 @@ def bwd_gate_up_weight(self, do1, input_x, expert_w1, clear_input=False):
     @paddle.no_grad()
     def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, output=None):
         self.origin_token_per_experts = origin_token_per_experts
+        # deal 0 size
+        dtype = paddle.bfloat16
         if hs_out is None:
             assert self.input_fp8 is not None
             assert self.input_scale is not None
             shape = self.input_fp8.shape
-            dtype = paddle.bfloat16
         else:
             if isinstance(hs_out, tuple):
                 shape = hs_out[0].shape
-                dtype = hs_out[0].dtype
             else:
                 shape = hs_out.shape
-                dtype = hs_out.dtype
 
         if shape[0] == 0:
             o3 = paddle.zeros(shape, dtype=dtype)
```
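The substantive change in `forward` is that `dtype` is now pinned to `paddle.bfloat16` before the branching, rather than inherited from `hs_out` inside each branch. A plausible reading (my inference, not stated in the commit) is that with FP8 activations the inherited dtype could be an FP8 storage type, so the 0-size fast path would hand downstream bf16 kernels a zero tensor of the wrong dtype. A minimal runnable sketch of the guarded path, with the non-empty branch elided:

```python
import paddle

def forward_zero_size_sketch(hs_out):
    # The fix: pin the dtype up front instead of reading hs_out.dtype,
    # which may be FP8-typed when activations are quantized.
    dtype = paddle.bfloat16
    shape = hs_out[0].shape if isinstance(hs_out, tuple) else hs_out.shape
    if shape[0] == 0:
        # 0 routed tokens: skip all expert GEMMs, return an empty bf16 tensor.
        return paddle.zeros(shape, dtype=dtype)
    raise NotImplementedError("non-empty path elided in this sketch")

out = forward_zero_size_sketch(paddle.zeros([0, 64], dtype=paddle.bfloat16))
print(out.shape, out.dtype)  # [0, 64] paddle.bfloat16
```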
```diff
@@ -958,6 +957,12 @@ def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, output=None):
 
     @paddle.no_grad()
     def backward(self, out_grad):
+        # deal 0 size
+        dtype = paddle.bfloat16
+        shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
+        if shape[0] == 0:
+            return paddle.zeros_like(out_grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)
+
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
```
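`backward` gets an early exit with the same shape test: when the expert group received no tokens, the recompute-and-GEMM pipeline below is skipped and zero gradients are returned for the hidden states and the routing probabilities. A standalone sketch of the guard; note that, as an assumption of mine, it normalizes the tuple case before `zeros_like`, whereas the committed code calls `paddle.zeros_like(out_grad, ...)` on `out_grad` directly:

```python
import paddle

def backward_zero_size_sketch(out_grad, unzipped_probs, dtype=paddle.bfloat16):
    shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
    if shape[0] == 0:
        # Forward computed nothing for this group, so both grads are zero.
        ref = out_grad[0] if isinstance(out_grad, tuple) else out_grad
        return paddle.zeros_like(ref, dtype=dtype), paddle.zeros_like(unzipped_probs, dtype=dtype)
    raise NotImplementedError("non-empty path elided in this sketch")

dx, dprobs = backward_zero_size_sketch(
    paddle.zeros([0, 64], dtype=paddle.bfloat16),
    paddle.zeros([0, 1], dtype=paddle.float32),
)
print(dx.shape, dprobs.dtype)  # [0, 64] paddle.bfloat16
```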
```diff
@@ -995,6 +1000,12 @@ def backward(self, out_grad):
 
     @paddle.no_grad()
     def backward_dx(self, out_grad):
+        # deal 0 size
+        dtype = paddle.bfloat16
+        shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
+        if shape[0] == 0:
+            return paddle.zeros_like(out_grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)
+
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
```
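`backward_dx` receives a guard that is a verbatim copy of the one in `backward`. A follow-up could hoist the duplicate into a shared helper; the sketch below is hypothetical and not part of this commit:

```python
import paddle

def _zero_size_grads(out_grad, probs, dtype=paddle.bfloat16):
    """Return (d_hidden, d_probs) zero gradients when no tokens were routed,
    else None. Hypothetical helper, not in the committed code."""
    ref = out_grad[0] if isinstance(out_grad, tuple) else out_grad
    if ref.shape[0] == 0:
        return paddle.zeros_like(ref, dtype=dtype), paddle.zeros_like(probs, dtype=dtype)
    return None

# backward() and backward_dx() could then both open with:
#     grads = _zero_size_grads(out_grad, self.unzipped_probs)
#     if grads is not None:
#         return grads
```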
```diff
@@ -1027,6 +1038,9 @@ def backward_dx(self, out_grad):
 
     @paddle.no_grad()
     def backward_dw(self):
+        # deal 0 size
+        if self.input_fp8 is None or self.input_fp8.shape[0] == 0:
+            return
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
```
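`backward_dw` differs from the other two guards: it tests the FP8 activations cached by `forward` rather than the incoming gradient, and it returns nothing, since with zero tokens there is no contribution to accumulate into the weight gradients. A minimal stand-in illustrating the control flow (the class name is hypothetical; only the attribute and method names come from the diff):

```python
import paddle

class DwGuardSketch:
    """Stand-in for the node class patched in fp8_utils.py; the real class
    name is not visible in this diff."""

    def __init__(self, input_fp8=None):
        self.input_fp8 = input_fp8  # FP8 activations cached by forward()

    @paddle.no_grad()
    def backward_dw(self):
        # deal 0 size: nothing cached, or zero tokens -> dW is a no-op.
        if self.input_fp8 is None or self.input_fp8.shape[0] == 0:
            return
        raise NotImplementedError("weight-gradient path elided in this sketch")

DwGuardSketch().backward_dw()                       # skips: nothing cached
DwGuardSketch(paddle.zeros([0, 64])).backward_dw()  # skips: zero tokens
```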
