@@ -917,18 +917,17 @@ def bwd_gate_up_weight(self, do1, input_x, expert_w1, clear_input=False):
     @paddle.no_grad()
     def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, output=None):
         self.origin_token_per_experts = origin_token_per_experts
+        # handle the zero-size case
+        dtype = paddle.bfloat16
         if hs_out is None:
             assert self.input_fp8 is not None
             assert self.input_scale is not None
             shape = self.input_fp8.shape
-            dtype = paddle.bfloat16
         else:
             if isinstance(hs_out, tuple):
                 shape = hs_out[0].shape
-                dtype = hs_out[0].dtype
             else:
                 shape = hs_out.shape
-                dtype = hs_out.dtype

         if shape[0] == 0:
             o3 = paddle.zeros(shape, dtype=dtype)
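Note on the forward hunk: the `dtype` default is hoisted ahead of the `hs_out` branches, so the zero-token early exit always has a dtype to build its empty output with, and the output dtype is now pinned to bf16 instead of following the input dtype. A minimal sketch of the resulting control flow (names like `cached_shape` are illustrative stand-ins, not the file's actual attributes):

```python
import paddle

def resolve_shape_then_guard(hs_out, cached_shape):
    # dtype is fixed before any branch: every exit path, including the
    # zero-token early return, now yields bf16 regardless of input dtype
    dtype = paddle.bfloat16
    if hs_out is None:
        shape = cached_shape            # FP8-cached activation path
    elif isinstance(hs_out, tuple):
        shape = hs_out[0].shape         # (tensor, scale) pair
    else:
        shape = hs_out.shape
    if shape[0] == 0:
        # no tokens routed to this rank's experts: skip the expert
        # GEMMs and hand back an empty bf16 tensor
        return paddle.zeros(shape, dtype=dtype)
    return None  # non-empty case: proceed with the real expert compute
```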
@@ -958,6 +957,12 @@ def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_ex

     @paddle.no_grad()
     def backward(self, out_grad):
+        # handle the zero-size case
+        dtype = paddle.bfloat16
+        grad = out_grad[0] if isinstance(out_grad, tuple) else out_grad
+        if grad.shape[0] == 0:
+            return paddle.zeros_like(grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)
+
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
@@ -995,6 +1000,12 @@ def backward(self, out_grad):

     @paddle.no_grad()
     def backward_dx(self, out_grad):
+        # handle the zero-size case
+        dtype = paddle.bfloat16
+        grad = out_grad[0] if isinstance(out_grad, tuple) else out_grad
+        if grad.shape[0] == 0:
+            return paddle.zeros_like(grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)
+
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
@@ -1027,6 +1038,9 @@ def backward_dx(self, out_grad):

     @paddle.no_grad()
     def backward_dw(self):
+        # handle the zero-size case
+        if self.input_fp8 is None or self.input_fp8.shape[0] == 0:
+            return
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
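All three backward entry points gain the same style of guard: when no tokens were routed to this rank, they return zero gradients (or, for `backward_dw`, simply return) before recomputing the expert weights. A standalone sketch of the shared pattern, with hypothetical names (`unzipped_probs` mirrors `self.unzipped_probs`):

```python
import paddle

def backward_zero_token_guard(out_grad, unzipped_probs):
    # unwrap tuple gradients before inspecting the shape, mirroring the
    # guard added to backward() and backward_dx()
    dtype = paddle.bfloat16
    grad = out_grad[0] if isinstance(out_grad, tuple) else out_grad
    if grad.shape[0] == 0:
        # zero rows: emit zero-filled gradients of the matching shapes
        # instead of launching the weight-recompute path
        return (
            paddle.zeros_like(grad, dtype=dtype),            # d(hidden states)
            paddle.zeros_like(unzipped_probs, dtype=dtype),  # d(routing probs)
        )
    return None  # non-empty case: caller recomputes expert_w1/expert_w2
```

`backward_dw` produces no tensor outputs, so its guard just returns early when the cached FP8 input is absent or empty.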