[not for land] debug accuracy logging for float8 training #2701

Open · wants to merge 2 commits into main
31 changes: 30 additions & 1 deletion test/float8/test_base.py
@@ -33,7 +33,10 @@
e5m2_dtype,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_linear_utils import (
_populate_debug_fqns,
convert_to_float8_training,
)
from torchao.float8.float8_ops import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
get_maybe_axiswise_dim,
@@ -400,6 +403,32 @@ def test_linear_from_recipe(
config,
)

@pytest.mark.parametrize(
"recipe_name",
[
Float8LinearRecipeName.TENSORWISE,
Float8LinearRecipeName.ROWWISE,
Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
],
)
def test_debug_logging(self, recipe_name):
x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
m = nn.Sequential(
nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
nn.Sequential(
nn.ReLU(),
nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
),
)
config = Float8LinearConfig.from_recipe_name(recipe_name)
object.__setattr__(config, "_enable_debug_logging", True)
m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)
m = torch.compile(m)
y = m(x)
y.sum().backward()
# TODO(before land): actually test the values logged to stdout

@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
6 changes: 6 additions & 0 deletions torchao/float8/config.py
@@ -204,6 +204,12 @@ class Float8LinearConfig:
# same value in the forward pass as the backward passes.
round_scales_to_power_of_2: bool = False

# If True, logs accuracy debugging information comparing high precision gemm
# outputs to their low precision versions, and prints it to stdout.
# Note: this flag is a prototype, has not been extensively tested, and the
# API may change.
_enable_debug_logging: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
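Since `Float8LinearConfig` is a frozen dataclass and the config in the new test is produced by `from_recipe_name`, the private flag is flipped after construction with `object.__setattr__`. A minimal sketch of that pattern, mirroring the test; the recipe choice is arbitrary and the `torchao.float8.config` import path is an assumption:

# sketch only, not part of the diff; import path is assumed
from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)
# the dataclass is frozen, so bypass the frozen __setattr__ for the prototype flag
object.__setattr__(config, "_enable_debug_logging", True)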
43 changes: 41 additions & 2 deletions torchao/float8/float8_linear.py
@@ -22,10 +22,19 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_utils import mean_absolute_percentage_error
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


@torch._dynamo.allow_in_graph
@torch._dynamo.disable
def log_mape(fqn, gemm_name, mape):
# TODO(future): use logging instead of print, will be more annoying to test
# with pytest
print(f"fqn: {fqn}, gemm_name: {gemm_name}, mape: {mape}")


# note: need to remove torch._dynamo.allow_in_graph for logging to work with torch.compile
# @torch._dynamo.allow_in_graph
class matmul_with_hp_or_float8_args(torch.autograd.Function):
"""
Like torch.matmul, but with the arguments in either high precision or float8.
@@ -41,10 +50,12 @@ def forward(
weight_hp_t: torch.Tensor,
linear_mm_config: LinearMMConfig,
config: Float8LinearConfig,
debug_fqn: Optional[str],
):
ctx.save_for_backward(input_hp, weight_hp_t)
ctx.linear_mm_config = linear_mm_config
ctx.config = config
ctx.debug_fqn = debug_fqn

c = config

@@ -87,13 +98,21 @@
orig_shape = input_maybe_fp8.shape
input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1])
res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)

if config._enable_debug_logging:
input_hp_reshaped = input_hp.reshape(-1, orig_shape[-1])
ref_result = torch.mm(input_hp_reshaped, weight_hp_t)
output_mape = mean_absolute_percentage_error(ref_result, res_bits)
log_mape(debug_fqn, "output", output_mape)

res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
return res_bits

@staticmethod
def backward(ctx, grad_output):
input_hp, weight_hp_t = ctx.saved_tensors
c = ctx.config
debug_fqn = ctx.debug_fqn

# the reshapes are needed in order to make the shapes compatible with
# torch.mm
@@ -144,6 +163,10 @@ def backward(ctx, grad_output):
grad_output_reshaped_maybe_fp8_dim0,
weight_t_maybe_fp8_dim0.t(),
)
if c._enable_debug_logging:
ref_grad_input = torch.mm(grad_output_reshaped, weight_hp_t.t())
grad_input_mape = mean_absolute_percentage_error(ref_grad_input, grad_input)
log_mape(debug_fqn, "grad_input", grad_input_mape)
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)
@@ -198,8 +221,19 @@ def backward(ctx, grad_output):
grad_output_reshaped_maybe_fp8_dim1.t(),
input_reshaped_maybe_fp8_dim1,
)
if c._enable_debug_logging:
# don't log if this gemm is in high precision
this_gemm_is_hp = (
c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED
)
if not this_gemm_is_hp:
ref_grad_weight = torch.mm(grad_output_reshaped.t(), input_hp_reshaped)
grad_weight_mape = mean_absolute_percentage_error(
ref_grad_weight, grad_weight
)
log_mape(debug_fqn, "grad_weight", grad_weight_mape)

empty_grads = None, None
empty_grads = None, None, None

return grad_input, grad_weight.t(), *empty_grads

@@ -252,6 +286,10 @@ def __init__(self, *args, **kwargs):
),
)

# debugging only, API may change at any time. This is expected to be
# set by the user in a separate API call.
self._debug_fqn: Optional[str] = None

def forward(self, input: torch.Tensor) -> torch.Tensor:
# Duplicate the autocast logic for F.linear, so that the output
# of our module has the right original precision
@@ -266,6 +304,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.weight.t(),
self.linear_mm_config,
self.config,
self._debug_fqn,
)

if self.bias is not None:
6 changes: 6 additions & 0 deletions torchao/float8/float8_linear_utils.py
@@ -196,3 +196,9 @@ def _auto_filter_for_tensorwise(
if K <= 4096 and N <= 1024:
return False
return True


def _populate_debug_fqns(model: nn.Module):
for name, mod in model.named_modules():
if isinstance(mod, Float8Linear):
mod._debug_fqn = name
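Putting the pieces together, a minimal end-to-end sketch of the intended workflow, mirroring the new test: the module shapes and recipe are illustrative, the config import path is an assumption, and a CUDA device is assumed as in the test.

# sketch only, not part of the diff
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear_utils import (
    _populate_debug_fqns,
    convert_to_float8_training,
)

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)
object.__setattr__(config, "_enable_debug_logging", True)  # frozen dataclass

m = nn.Sequential(
    nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
    nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
)
m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)  # attach each Float8Linear's FQN so logs can be attributed
m = torch.compile(m)

x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()  # MAPE lines for each instrumented gemm are printed to stdout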
9 changes: 9 additions & 0 deletions torchao/float8/float8_utils.py
@@ -239,3 +239,12 @@ def pad_tensor_for_matmul(
def _round_scale_down_to_power_of_2(scale: torch.Tensor):
assert scale.dtype == torch.float32, "scale must be float32 tensor"
return torch.exp2(torch.floor(torch.log2(scale)))


@torch.no_grad()
def mean_absolute_percentage_error(x_ref, x):
tmp = torch.abs(x_ref - x) / torch.clamp(torch.abs(x_ref), min=1e-9)
# clamp to prevent reference values close to 0 from
# significantly impacting the results
tmp = torch.clamp(tmp, max=1e3)
return torch.mean(tmp)
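For intuition, a quick self-contained sketch of what the metric reports; the tensor shape and perturbation scale below are arbitrary stand-ins for low-precision error:

# sketch only, not part of the diff
import torch

from torchao.float8.float8_utils import mean_absolute_percentage_error

x_ref = torch.randn(64, 64)
x = x_ref + 1e-3 * torch.randn(64, 64)  # stand-in for quantization error
# elementwise |x_ref - x| / max(|x_ref|, 1e-9), clamped at 1e3, then averaged
print(mean_absolute_percentage_error(x_ref, x))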