     LinearMMConfig,
     ScaledMMConfig,
 )
+from torchao.float8.float8_utils import mean_absolute_percentage_error
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-@torch._dynamo.allow_in_graph
+@torch._dynamo.disable
+def log_mape(fqn, gemm_name, mape):
+    # TODO(future): use logging instead of print, will be more annoying to test
+    # with pytest
+    print(f"fqn: {fqn}, gemm_name: {gemm_name}, mape: {mape}")
+
+
+# note: need to remove torch._dynamo.allow_in_graph for logging to work with torch.compile
+# @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
     Like torch.matmul, but with the arguments in either high precision or float8.
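The mean_absolute_percentage_error helper imported above comes from torchao.float8.float8_utils and is not shown in this diff. As a minimal sketch of what such a helper could compute (the clamp guard and the exact reduction are assumptions, not the actual torchao implementation):

import torch

def mean_absolute_percentage_error(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # assumed definition: mean of |ref - actual| / |ref|, with a clamp to
    # guard against division by zero where the reference has exact zeros
    return ((ref - actual).abs() / ref.abs().clamp(min=1e-12)).mean()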
@@ -41,10 +50,12 @@ def forward(
         weight_hp_t: torch.Tensor,
         linear_mm_config: LinearMMConfig,
         config: Float8LinearConfig,
+        debug_fqn: Optional[str],
     ):
         ctx.save_for_backward(input_hp, weight_hp_t)
         ctx.linear_mm_config = linear_mm_config
         ctx.config = config
+        ctx.debug_fqn = debug_fqn
 
         c = config
@@ -87,13 +98,21 @@ def forward(
         orig_shape = input_maybe_fp8.shape
         input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1])
         res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)
+
+        if config._enable_debug_logging:
+            input_hp_reshaped = input_hp.reshape(-1, orig_shape[-1])
+            ref_result = torch.mm(input_hp_reshaped, weight_hp_t)
+            output_mape = mean_absolute_percentage_error(ref_result, res_bits)
+            log_mape(debug_fqn, "output", output_mape)
+
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
         return res_bits
 
     @staticmethod
     def backward(ctx, grad_output):
         input_hp, weight_hp_t = ctx.saved_tensors
         c = ctx.config
+        debug_fqn = ctx.debug_fqn
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
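The forward-path check above recomputes the gemm with the original high precision operands and reports the MAPE of the (possibly float8) result against that reference. As a self-contained illustration of the same measurement, independent of torchao, comparing a bfloat16 gemm against an fp32 reference:

import torch

x = torch.randn(64, 128)
w = torch.randn(128, 256)
ref = torch.mm(x, w)                                   # high precision reference
approx = torch.mm(x.bfloat16(), w.bfloat16()).float()  # lower precision gemm
mape = ((ref - approx).abs() / ref.abs().clamp(min=1e-12)).mean()
print(f"bfloat16 mm MAPE vs fp32: {mape:.2e}")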
@@ -144,6 +163,10 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim0,
             weight_t_maybe_fp8_dim0.t(),
         )
+        if c._enable_debug_logging:
+            ref_grad_input = torch.mm(grad_output_reshaped, weight_hp_t.t())
+            grad_input_mape = mean_absolute_percentage_error(ref_grad_input, grad_input)
+            log_mape(debug_fqn, "grad_input", grad_input_mape)
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
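As an illustrative aside (not part of the diff): the reference gemms used by these backward checks are the standard matmul gradients. For y = x @ w_t, the gradient with respect to x is g @ w_t.T and the gradient with respect to w_t is x.T @ g, which plain autograd confirms:

import torch

x = torch.randn(8, 16, requires_grad=True)
w_t = torch.randn(16, 32, requires_grad=True)  # weight already transposed, as in the diff
y = torch.mm(x, w_t)
g = torch.randn_like(y)
y.backward(g)
# these are exactly the reference gemms the debug checks recompute
assert torch.allclose(x.grad, torch.mm(g, w_t.t()))
assert torch.allclose(w_t.grad, torch.mm(x.t(), g))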
@@ -198,8 +221,19 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim1.t(),
             input_reshaped_maybe_fp8_dim1,
         )
+        if c._enable_debug_logging:
+            # don't log if this gemm is in high precision
+            this_gemm_is_hp = (
+                c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED
+            )
+            if not this_gemm_is_hp:
+                ref_grad_weight = torch.mm(grad_output_reshaped.t(), input_hp_reshaped)
+                grad_weight_mape = mean_absolute_percentage_error(
+                    ref_grad_weight, grad_weight
+                )
+                log_mape(debug_fqn, "grad_weight", grad_weight_mape)
 
-        empty_grads = None, None
+        empty_grads = None, None, None
 
         return grad_input, grad_weight.t(), *empty_grads
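Note on the empty_grads change: autograd.Function.backward must return one gradient per forward input, and forward now takes three non-tensor arguments (linear_mm_config, config, and the new debug_fqn), so the tuple of None placeholders grows from two to three.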
@@ -252,6 +286,10 @@ def __init__(self, *args, **kwargs):
             ),
         )
 
+        # debugging only, API may change at any time. This is expected to be
+        # set by the user in a separate API call.
+        self._debug_fqn: Optional[str] = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
@@ -266,6 +304,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight.t(),
             self.linear_mm_config,
             self.config,
+            self._debug_fqn,
         )
 
         if self.bias is not None:
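Taken together, a plausible way to drive this debug feature end to end. convert_to_float8_training and Float8Linear are existing torchao APIs, but the way _enable_debug_logging is set below is an assumption based only on this diff, which shows the flag being read, not how it is meant to be enabled:

import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear

model = nn.Sequential(nn.Linear(256, 512), nn.Linear(512, 256)).cuda().bfloat16()

config = Float8LinearConfig()
# assumption: set the private debug flag directly (works whether or not the
# config dataclass is frozen)
object.__setattr__(config, "_enable_debug_logging", True)
convert_to_float8_training(model, config=config)

# "expected to be set by the user in a separate API call": tag each converted
# linear with its fully qualified name so log lines are attributable
for fqn, mod in model.named_modules():
    if isinstance(mod, Float8Linear):
        mod._debug_fqn = fqn

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
# expected output, one line per debug-logged gemm, e.g.:
#   fqn: 0, gemm_name: output, mape: ...
#   fqn: 0, gemm_name: grad_input, mape: ...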