Commit a0d96fb

[not for land] debug accuracy logging for float8 training

Summary:

A lightweight logging flag to log the SQNR between the float8 gemm output and the bf16 gemm output.

Test Plan:

```bash
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: da0f0f4
ghstack-comment-id: 3160382784
Pull Request resolved: #2701
1 parent 2db4c76 · commit a0d96fb

File tree: 4 files changed, +156 −5 lines changed

- test/float8/test_base.py
- torchao/float8/config.py
- torchao/float8/float8_linear.py
- torchao/float8/float8_linear_utils.py
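
Before the per-file diffs, a rough sketch of how the hook added in this commit is intended to be used, based on the test below. The `sqnr` helper and the toy model are illustrative assumptions, not part of this diff; `Float8LinearConfig._debug_logging_fn`, `convert_to_float8_training`, and `_populate_debug_fqns` are the APIs this commit touches. Assumes a CUDA device with float8 support.

```python
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear_utils import (
    _populate_debug_fqns,
    convert_to_float8_training,
)


@torch.no_grad()
def sqnr(x_ref: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB (illustrative helper, not part of this PR)
    return 10 * torch.log10(x_ref.pow(2).mean() / (x_ref - x).pow(2).mean())


@torch._dynamo.disable
def debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
    # compare each float8 gemm against the corresponding bf16 gemm
    print(f"{fqn=} {gemm_name=} sqnr={sqnr(a_hp @ b_hp, a_fp8 @ b_fp8).item():.1f}")


# toy model, for illustration only
m = nn.Sequential(nn.Linear(256, 512, bias=False)).cuda().bfloat16()
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)
# the config dataclass is frozen, so the test below sets the prototype field this way
object.__setattr__(config, "_debug_logging_fn", debug_logging_fn)
m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)

x = torch.randn(16, 256, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()  # hook fires once per gemm (output / grad_input / grad_weight)
```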

test/float8/test_base.py

Lines changed: 86 additions & 1 deletion
```diff
@@ -22,7 +22,10 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_linear_utils import (
+    _populate_debug_fqns,
+    convert_to_float8_training,
+)
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -395,6 +398,88 @@ def test_linear_from_recipe(
             config,
         )
 
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.TENSORWISE,
+            Float8LinearRecipeName.ROWWISE,
+            Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        ],
+    )
+    def test_debug_logging(self, recipe_name):
+        x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
+        m = nn.Sequential(
+            nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
+            nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
+            ),
+        )
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
+
+        @torch.no_grad()
+        def mean_absolute_percentage_error(x_ref, x):
+            tmp = torch.abs(x_ref - x) / torch.clamp(torch.abs(x_ref), min=1e-9)
+            # trim to avoid values close to 0 from
+            # significantly impacting the results
+            tmp = torch.clamp(tmp, max=1e3)
+            return torch.mean(tmp)
+
+        iter_counter = 0
+        iter_fqn_gemm_name_to_data = {}
+
+        @torch._dynamo.disable
+        def debug_logging_fn(fqn, gemm_name, a_hp, b_hp, a_fp8, b_fp8):
+            """
+            Example debugging function - this is user defined, easy to customize
+            1. captures M, K, N
+            2. captures MAPE for high precision vs float8 gemm
+            3. leaves data on GPU, so the user can move it to CPU at their
+               convenience
+            """
+            M, K = a_hp.shape
+            K2, N = b_hp.shape
+            assert K == K2
+            res_hp = a_hp @ b_hp
+            res_fp8 = a_fp8 @ b_fp8
+            mape = mean_absolute_percentage_error(res_hp, res_fp8)
+            iter_fqn_gemm_name_to_data[(iter_counter, fqn, gemm_name)] = (M, K, N), mape
+
+        object.__setattr__(config, "_debug_logging_fn", debug_logging_fn)
+        m = convert_to_float8_training(m, config=config)
+        _populate_debug_fqns(m)
+
+        # iter 0
+        m = torch.compile(m)
+        y = m(x)
+        y.sum().backward()
+
+        # iter 1
+        iter_counter += 1
+        m = torch.compile(m)
+        y = m(x)
+        y.sum().backward()
+
+        if recipe_name == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
+            # check length is num_float8_layers * num_gemms_per_layer * num_iters
+            assert len(iter_fqn_gemm_name_to_data) == 2 * 2 * (iter_counter + 1)
+            # check that some of the expected debug logs exist
+            assert (0, "0", "output") in iter_fqn_gemm_name_to_data
+            assert (1, "1.1", "grad_input") in iter_fqn_gemm_name_to_data
+        else:
+            # check length is num_float8_layers * num_gemms_per_layer * num_iters
+            assert len(iter_fqn_gemm_name_to_data) == 2 * 3 * (iter_counter + 1)
+            # check that some of the expected debug logs exist
+            assert (0, "0", "output") in iter_fqn_gemm_name_to_data
+            assert (0, "1.1", "grad_weight") in iter_fqn_gemm_name_to_data
+            assert (1, "1.1", "grad_input") in iter_fqn_gemm_name_to_data
+
+        # check logged data is what we expect
+        example_data = iter_fqn_gemm_name_to_data[(1, "1.1", "grad_input")]
+        assert example_data[0] == (16, 64, 32)
+        assert type(example_data[1]) == torch.Tensor
+        assert example_data[1].shape == torch.Size()
+
     @pytest.mark.parametrize(
         "emulate", [True, False] if is_sm_at_least_89() else [True]
     )
```

torchao/float8/config.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -7,7 +7,7 @@
 import enum
 import logging
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 
@@ -204,6 +204,17 @@ class Float8LinearConfig:
     # same value in the forward pass as the backward passes.
     round_scales_to_power_of_2: bool = False
 
+    # If specified, the debug fqn, the name of each gemm
+    # (output/grad_input/grad_weight) and the high_precision and float8 inputs to
+    # each gemm are passed to this function at each iteration. The intended use
+    # case is accuracy and performance logging for debugging. This feature is
+    # prototype and the API may change.
+    _debug_logging_fn: Optional[
+        Callable[
+            [str, str, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None
+        ]
+    ] = None
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
```

torchao/float8/float8_linear.py

Lines changed: 48 additions & 3 deletions
```diff
@@ -25,7 +25,9 @@
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-@torch._dynamo.allow_in_graph
+# TODO(before land): remove two lines of comments below
+# note: need to remove torch._dynamo.allow_in_graph for logging to work with torch.compile
+# @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
     Like torch.matmul, but with the arguments in either high precision or float8.
@@ -41,10 +43,12 @@ def forward(
         weight_hp_t: torch.Tensor,
         linear_mm_config: LinearMMConfig,
         config: Float8LinearConfig,
+        debug_fqn: Optional[str],
     ):
         ctx.save_for_backward(input_hp, weight_hp_t)
         ctx.linear_mm_config = linear_mm_config
         ctx.config = config
+        ctx.debug_fqn = debug_fqn
 
         c = config
 
@@ -87,13 +91,26 @@ def forward(
         orig_shape = input_maybe_fp8.shape
         input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1])
         res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)
+
+        if config._debug_logging_fn is not None:
+            input_hp_reshaped = input_hp.reshape(-1, orig_shape[-1])
+            config._debug_logging_fn(
+                debug_fqn,
+                "output",
+                input_hp_reshaped,
+                weight_hp_t,
+                input_maybe_fp8_reshaped,
+                weight_maybe_fp8_t,
+            )
+
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
         return res_bits
 
     @staticmethod
     def backward(ctx, grad_output):
         input_hp, weight_hp_t = ctx.saved_tensors
         c = ctx.config
+        debug_fqn = ctx.debug_fqn
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -144,6 +161,15 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim0,
             weight_t_maybe_fp8_dim0.t(),
         )
+        if c._debug_logging_fn is not None:
+            c._debug_logging_fn(
+                debug_fqn,
+                "grad_input",
+                grad_output_reshaped,
+                weight_hp_t.t(),
+                grad_output_reshaped_maybe_fp8_dim0,
+                weight_t_maybe_fp8_dim0.t(),
+            )
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
@@ -198,8 +224,22 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim1.t(),
             input_reshaped_maybe_fp8_dim1,
         )
-
-        empty_grads = None, None
+        if c._debug_logging_fn is not None:
+            # don't log if this gemm is in high precision
+            this_gemm_is_hp = (
+                c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED
+            )
+            if not this_gemm_is_hp:
+                c._debug_logging_fn(
+                    debug_fqn,
+                    "grad_weight",
+                    grad_output_reshaped.t(),
+                    input_hp_reshaped,
+                    grad_output_reshaped_maybe_fp8_dim1.t(),
+                    input_reshaped_maybe_fp8_dim1,
+                )
+
+        empty_grads = None, None, None
 
         return grad_input, grad_weight.t(), *empty_grads
 
@@ -252,6 +292,10 @@ def __init__(self, *args, **kwargs):
             ),
         )
 
+        # debugging only, API may change at any time. This is expected to be
+        # set by the user in a separate API call.
+        self._debug_fqn: Optional[str] = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
@@ -266,6 +310,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight.t(),
             self.linear_mm_config,
             self.config,
+            self._debug_fqn,
         )
 
         if self.bias is not None:
```

torchao/float8/float8_linear_utils.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -198,3 +198,13 @@ def _auto_filter_for_tensorwise(
     if K <= 4096 and N <= 1024:
         return False
     return True
+
+
+def _populate_debug_fqns(model: nn.Module):
+    """Populates the `_debug_fqn` attribute on each `Float8Linear` child of
+    `model`, useful for debugging. Note that this API is prototype and may
+    change in the future.
+    """
+    for name, mod in model.named_modules():
+        if isinstance(mod, Float8Linear):
+            mod._debug_fqn = name
```
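
As a usage note (an illustrative sketch reusing the converted model `m` from the example near the top), the helper simply mirrors `named_modules()` names onto each `Float8Linear`, which is where the FQNs like `"0"` and `"1.1"` asserted in the test come from:

```python
from torchao.float8.float8_linear import Float8Linear

# after convert_to_float8_training(m, config=config) and _populate_debug_fqns(m),
# each Float8Linear carries its module path as its debug FQN
for name, mod in m.named_modules():
    if isinstance(mod, Float8Linear):
        print(name, mod._debug_fqn)
```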

0 commit comments
