
Commit 2674d1d

[not for land] debug accuracy logging for float8 training
Summary:

A lightweight logging flag to log the SQNR between the float8 gemm output and
the bf16 gemm output.

Test Plan:

```bash
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f419ea2
ghstack-comment-id: 3160382784
Pull Request resolved: #2701
1 parent 5d99ce4 commit 2674d1d
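
For context, a minimal sketch of how the new flag is meant to be exercised, mirroring the `test_debug_logging` test in this diff. The `Float8LinearRecipeName` import path and the exact log text are assumptions; everything else follows the test code, and the flag plus `_populate_debug_fqns` are prototype/private surface that may change.

```python
# Minimal sketch based on the new test in this commit; not a stable API.
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName  # assumed import path
from torchao.float8.float8_linear_utils import (
    _populate_debug_fqns,
    convert_to_float8_training,
)

m = nn.Sequential(
    nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
    nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
)
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)
# the config dataclass is frozen, so the prototype flag is set the same way the test does
object.__setattr__(config, "_enable_debug_logging", True)

m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)  # gives each Float8Linear an FQN to include in its log lines
m = torch.compile(m)

x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
y = m(x)
y.sum().backward()
# each Float8Linear then prints lines like:
#   fqn: 0, gemm_name: output, sqnr: ...
```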

File tree

4 files changed: +81 -3 lines changed

test/float8/test_base.py

Lines changed: 30 additions & 1 deletion
@@ -33,7 +33,10 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_linear_utils import (
+    _populate_debug_fqns,
+    convert_to_float8_training,
+)
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -400,6 +403,32 @@ def test_linear_from_recipe(
             config,
         )
 
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.TENSORWISE,
+            Float8LinearRecipeName.ROWWISE,
+            Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        ],
+    )
+    def test_debug_logging(self, recipe_name):
+        x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
+        m = nn.Sequential(
+            nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
+            nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
+            ),
+        )
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
+        object.__setattr__(config, "_enable_debug_logging", True)
+        m = convert_to_float8_training(m, config=config)
+        _populate_debug_fqns(m)
+        m = torch.compile(m)
+        y = m(x)
+        y.sum().backward()
+        # TODO(before land): actually test the values logged to stdout
+
     @pytest.mark.parametrize(
         "emulate", [True, False] if is_sm_at_least_89() else [True]
     )
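
Note that `log_sqnr` uses `print`, so pytest's default output capturing hides the log lines; running the new test with capture disabled (for example `pytest test/float8/test_base.py -k test_debug_logging -s`) shows them, and the TODO above indicates the logged values are not yet asserted on.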

torchao/float8/config.py

Lines changed: 6 additions & 0 deletions
@@ -204,6 +204,12 @@ class Float8LinearConfig:
     # same value in the forward pass as the backward passes.
     round_scales_to_power_of_2: bool = False
 
+    # If True, captures accuracy debugging logging comparing high precision gemm
+    # outputs to their low precision versions, and outputs it to stdout
+    # Note: this flag is in prototype, has not been extensively tested and the
+    # API may change.
+    _enable_debug_logging: bool = False
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass

torchao/float8/float8_linear.py

Lines changed: 39 additions & 2 deletions
@@ -22,10 +22,19 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
+from torchao.float8.float8_utils import compute_error
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-@torch._dynamo.allow_in_graph
+@torch._dynamo.disable
+def log_sqnr(fqn, gemm_name, sqnr):
+    # TODO(future): use logging instead of print, will be more annoying to test
+    # with pytest
+    print(f"fqn: {fqn}, gemm_name: {gemm_name}, sqnr: {sqnr}")
+
+
+# note: need to remove torch._dynamo.allow_in_graph for logging to work with torch.compile
+# @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
     Like torch.matmul, but with the arguments in either high precision or float8.
@@ -41,10 +50,12 @@ def forward(
         weight_hp_t: torch.Tensor,
         linear_mm_config: LinearMMConfig,
         config: Float8LinearConfig,
+        debug_fqn: Optional[str],
     ):
         ctx.save_for_backward(input_hp, weight_hp_t)
         ctx.linear_mm_config = linear_mm_config
         ctx.config = config
+        ctx.debug_fqn = debug_fqn
 
         c = config
 
@@ -87,13 +98,21 @@ def forward(
         orig_shape = input_maybe_fp8.shape
         input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1])
         res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)
+
+        if config._enable_debug_logging:
+            input_hp_reshaped = input_hp.reshape(-1, orig_shape[-1])
+            ref_result = torch.mm(input_hp_reshaped, weight_hp_t)
+            output_sqnr = compute_error(ref_result, res_bits)
+            log_sqnr(debug_fqn, "output", output_sqnr)
+
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
         return res_bits
 
     @staticmethod
     def backward(ctx, grad_output):
         input_hp, weight_hp_t = ctx.saved_tensors
         c = ctx.config
+        debug_fqn = ctx.debug_fqn
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -144,6 +163,10 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim0,
             weight_t_maybe_fp8_dim0.t(),
         )
+        if c._enable_debug_logging:
+            ref_grad_input = torch.mm(grad_output_reshaped, weight_hp_t.t())
+            grad_input_sqnr = compute_error(ref_grad_input, grad_input)
+            log_sqnr(debug_fqn, "grad_input", grad_input_sqnr)
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
@@ -198,8 +221,17 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim1.t(),
             input_reshaped_maybe_fp8_dim1,
         )
+        if c._enable_debug_logging:
+            # don't log if this gemm is in high precision
+            this_gemm_is_hp = (
+                c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED
+            )
+            if not this_gemm_is_hp:
+                ref_grad_weight = torch.mm(grad_output_reshaped.t(), input_hp_reshaped)
+                grad_weight_sqnr = compute_error(ref_grad_weight, grad_weight)
+                log_sqnr(debug_fqn, "grad_weight", grad_weight_sqnr)
 
-        empty_grads = None, None
+        empty_grads = None, None, None
 
         return grad_input, grad_weight.t(), *empty_grads
 
@@ -252,6 +284,10 @@ def __init__(self, *args, **kwargs):
             ),
         )
 
+        # debugging only, API may change at any time. This is expected to be
+        # set by the user in a separate API call.
+        self._debug_fqn: Optional[str] = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
@@ -266,6 +302,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight.t(),
             self.linear_mm_config,
             self.config,
+            self._debug_fqn,
         )
 
         if self.bias is not None:
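
The logged number is whatever `compute_error` from `torchao.float8.float8_utils` returns for (high precision reference, low precision result). As a rough mental model, SQNR in dB is commonly computed as below; `sqnr_db` here is an illustrative stand-in, not the actual `compute_error` implementation:

```python
import torch

def sqnr_db(ref: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels:
    #   20 * log10(||ref|| / ||ref - candidate||)
    # Higher is better; identical tensors would give +inf, a lossy cast a finite positive value.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - candidate.float())
    return 20 * torch.log10(signal / noise)
```

With `_enable_debug_logging` set, each `Float8Linear` emits three such values per training step: `output` in the forward, plus `grad_input` and `grad_weight` in the backward (the `grad_weight` log is skipped when that gemm runs in high precision).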

torchao/float8/float8_linear_utils.py

Lines changed: 6 additions & 0 deletions
@@ -196,3 +196,9 @@ def _auto_filter_for_tensorwise(
     if K <= 4096 and N <= 1024:
         return False
     return True
+
+
+def _populate_debug_fqns(model: nn.Module):
+    for name, mod in model.named_modules():
+        if isinstance(mod, Float8Linear):
+            mod._debug_fqn = name
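
`_populate_debug_fqns` simply reuses the names produced by `nn.Module.named_modules()`. A small illustration of which names the two `Float8Linear` modules in the new test would end up logging under (plain `nn.Linear` stands in for `Float8Linear` here):

```python
import torch.nn as nn

m = nn.Sequential(
    nn.Linear(16, 32, bias=False),
    nn.Sequential(
        nn.ReLU(),
        nn.Linear(32, 64, bias=False),
    ),
)
# named_modules() walks the tree and yields dotted names: "", "0", "1", "1.0", "1.1"
for name, mod in m.named_modules():
    if isinstance(mod, nn.Linear):
        print(name)  # prints "0" then "1.1" -> these become the _debug_fqn values
```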
