Skip to content

Commit b5b94e5

Browse files
authored
delete add (#10813)
1 parent dad134e commit b5b94e5

File tree

1 file changed

+38
-8
lines changed

1 file changed

+38
-8
lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def padding(x, axis):
9494

9595
class FP8LinearFunction(paddle.autograd.PyLayer):
9696
@staticmethod
97-
def forward(ctx, x, weight):
97+
def forward(ctx, x, custom_map):
98+
weight = custom_map.weight
9899
x_orig_shape = x.shape
99100
x_t = x.T
100101

@@ -156,9 +157,23 @@ def backward(ctx, dout):
156157
dx = dx.reshape(dx_orig_shape)
157158

158159
# ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
159-
dweight = kitchen_fp8_gemm(x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, rtn_dtype=paddle.float32)
160+
if hasattr(weight, "main_grad"):
161+
if weight.main_grad is None:
162+
weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32)
163+
kitchen_fp8_gemm(
164+
x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.main_grad, rtn_dtype=paddle.float32
165+
)
166+
else:
167+
if weight.grad is None:
168+
weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32)
169+
kitchen_fp8_gemm(
170+
x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.grad, rtn_dtype=paddle.float32
171+
)
160172

161-
return dx, dweight
173+
if hasattr(weight, "_apply_backward_hook"):
174+
weight._apply_backward_hook()
175+
176+
return dx
162177

163178

164179
class FP8Linear(paddle.nn.Layer):
@@ -173,12 +188,13 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
173188
)
174189

175190
def forward(self, x):
176-
return FP8LinearFunction.apply(x, self.weight)
191+
return FP8LinearFunction.apply(x, self)
177192

178193

179194
class FP8LinearKeepXFunction(paddle.autograd.PyLayer):
180195
@staticmethod
181-
def forward(ctx, x, weight):
196+
def forward(ctx, x, custom_map):
197+
weight = custom_map.weight
182198
x_orig_shape = x.shape
183199

184200
# deep_gemm only support 2D
@@ -234,9 +250,23 @@ def backward(ctx, dout):
234250
dx = dx.reshape(dx_orig_shape)
235251

236252
# ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
237-
dweight = kitchen_fp8_gemm(x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, rtn_dtype=paddle.float32)
253+
if hasattr(weight, "main_grad"):
254+
if weight.main_grad is None:
255+
weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32)
256+
kitchen_fp8_gemm(
257+
x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.main_grad, rtn_dtype=paddle.float32
258+
)
259+
else:
260+
if weight.grad is None:
261+
weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32)
262+
kitchen_fp8_gemm(
263+
x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight.grad, rtn_dtype=paddle.float32
264+
)
238265

239-
return dx, dweight
266+
if hasattr(weight, "_apply_backward_hook"):
267+
weight._apply_backward_hook()
268+
269+
return dx
240270

241271

242272
class FP8KeepXLinear(paddle.nn.Layer):
@@ -251,7 +281,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
251281
)
252282

253283
def forward(self, x):
254-
return FP8LinearKeepXFunction.apply(x, self.weight)
284+
return FP8LinearKeepXFunction.apply(x, self)
255285

256286

257287
def fp8_mlp_fwd(x, w1, w2):

0 commit comments

Comments (0)