@@ -94,7 +94,8 @@ def padding(x, axis):
94
94
95
95
class FP8LinearFunction (paddle .autograd .PyLayer ):
96
96
@staticmethod
97
- def forward (ctx , x , weight ):
97
+ def forward (ctx , x , custom_map ):
98
+ weight = custom_map .weight
98
99
x_orig_shape = x .shape
99
100
x_t = x .T
100
101
@@ -156,9 +157,23 @@ def backward(ctx, dout):
156
157
dx = dx .reshape (dx_orig_shape )
157
158
158
159
# ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
159
- dweight = kitchen_fp8_gemm (x_t_fp8 , x_t_scale , dout_t_fp8 , dout_t_scale , True , True , rtn_dtype = paddle .float32 )
160
+ if hasattr (weight , "main_grad" ):
161
+ if weight .main_grad is None :
162
+ weight .main_grad = paddle .zeros (shape = weight .shape , dtype = paddle .float32 )
163
+ kitchen_fp8_gemm (
164
+ x_t_fp8 , x_t_scale , dout_t_fp8 , dout_t_scale , True , True , weight .main_grad , rtn_dtype = paddle .float32
165
+ )
166
+ else :
167
+ if weight .grad is None :
168
+ weight .grad = paddle .zeros (shape = weight .shape , dtype = paddle .float32 )
169
+ kitchen_fp8_gemm (
170
+ x_t_fp8 , x_t_scale , dout_t_fp8 , dout_t_scale , True , True , weight .grad , rtn_dtype = paddle .float32
171
+ )
160
172
161
- return dx , dweight
173
+ if hasattr (weight , "_apply_backward_hook" ):
174
+ weight ._apply_backward_hook ()
175
+
176
+ return dx
162
177
163
178
164
179
class FP8Linear (paddle .nn .Layer ):
@@ -173,12 +188,13 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
173
188
)
174
189
175
190
def forward (self , x ):
176
- return FP8LinearFunction .apply (x , self . weight )
191
+ return FP8LinearFunction .apply (x , self )
177
192
178
193
179
194
class FP8LinearKeepXFunction (paddle .autograd .PyLayer ):
180
195
@staticmethod
181
- def forward (ctx , x , weight ):
196
+ def forward (ctx , x , custom_map ):
197
+ weight = custom_map .weight
182
198
x_orig_shape = x .shape
183
199
184
200
# deep_gemm only support 2D
@@ -234,9 +250,23 @@ def backward(ctx, dout):
234
250
dx = dx .reshape (dx_orig_shape )
235
251
236
252
# ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
237
- dweight = kitchen_fp8_gemm (x_t_fp8 , x_t_scale , dout_t_fp8 , dout_t_scale , True , True , rtn_dtype = paddle .float32 )
253
+ if hasattr (weight , "main_grad" ):
254
+ if weight .main_grad is None :
255
+ weight .main_grad = paddle .zeros (shape = weight .shape , dtype = paddle .float32 )
256
+ kitchen_fp8_gemm (
257
+ x_t_fp8 , x_t_scale , dout_t_fp8 , dout_t_scale , True , True , weight .main_grad , rtn_dtype = paddle .float32
258
+ )
259
+ else :
260
+ if weight .grad is None :
261
+ weight .grad = paddle .zeros (shape = weight .shape , dtype = paddle .float32 )
262
+ kitchen_fp8_gemm (
263
+ x_t_fp8 , x_t_scale , dout_t_fp8 , dout_t_scale , True , True , weight .grad , rtn_dtype = paddle .float32
264
+ )
238
265
239
- return dx , dweight
266
+ if hasattr (weight , "_apply_backward_hook" ):
267
+ weight ._apply_backward_hook ()
268
+
269
+ return dx
240
270
241
271
242
272
class FP8KeepXLinear (paddle .nn .Layer ):
@@ -251,7 +281,7 @@ def __init__(self, in_features: int, out_features: int, bias_attr: bool = False)
251
281
)
252
282
253
283
def forward (self , x ):
254
- return FP8LinearKeepXFunction .apply (x , self . weight )
284
+ return FP8LinearKeepXFunction .apply (x , self )
255
285
256
286
257
287
def fp8_mlp_fwd (x , w1 , w2 ):
0 commit comments