diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index c19478e02a..62bfee00f5 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -33,7 +33,10 @@
     e5m2_dtype,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_linear_utils import (
+    _populate_debug_fqns,
+    convert_to_float8_training,
+)
 from torchao.float8.float8_ops import addmm_float8_unwrapped
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -400,6 +403,32 @@ def test_linear_from_recipe(
             config,
         )
 
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.TENSORWISE,
+            Float8LinearRecipeName.ROWWISE,
+            Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        ],
+    )
+    def test_debug_logging(self, recipe_name):
+        x = torch.randn(1, 16, 16, device="cuda", dtype=torch.bfloat16)
+        m = nn.Sequential(
+            nn.Linear(16, 32, bias=False, device="cuda", dtype=torch.bfloat16),
+            nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16),
+            ),
+        )
+        config = Float8LinearConfig.from_recipe_name(recipe_name)
+        object.__setattr__(config, "_enable_debug_logging", True)
+        m = convert_to_float8_training(m, config=config)
+        _populate_debug_fqns(m)
+        m = torch.compile(m)
+        y = m(x)
+        y.sum().backward()
+        # TODO(before land): actually test the values logged to stdout
+
     @pytest.mark.parametrize(
         "emulate", [True, False] if is_sm_at_least_89() else [True]
     )
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index 939f68e59a..eb7a871a87 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -204,6 +204,12 @@ class Float8LinearConfig:
     # same value in the forward pass as the backward passes.
     round_scales_to_power_of_2: bool = False
 
+    # If True, logs accuracy debugging information comparing high precision gemm
+    # outputs to their low precision versions, and prints it to stdout.
+    # Note: this flag is a prototype, has not been extensively tested, and the
+    # API may change.
+    _enable_debug_logging: bool = False
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index a946835a4d..cc511c9675 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -22,10 +22,19 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
+from torchao.float8.float8_utils import mean_absolute_percentage_error
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-@torch._dynamo.allow_in_graph
+@torch._dynamo.disable
+def log_mape(fqn, gemm_name, mape):
+    # TODO(future): use logging instead of print; logging will be more annoying
+    # to test with pytest
+    print(f"fqn: {fqn}, gemm_name: {gemm_name}, mape: {mape}")
+
+
+# note: need to remove torch._dynamo.allow_in_graph for logging to work with torch.compile
+# @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
     Like torch.matmul, but with the arguments in either high precision or float8.
@@ -41,10 +50,12 @@ def forward(
         weight_hp_t: torch.Tensor,
         linear_mm_config: LinearMMConfig,
         config: Float8LinearConfig,
+        debug_fqn: Optional[str],
     ):
         ctx.save_for_backward(input_hp, weight_hp_t)
         ctx.linear_mm_config = linear_mm_config
         ctx.config = config
+        ctx.debug_fqn = debug_fqn
 
         c = config
 
@@ -87,6 +98,13 @@ def forward(
         orig_shape = input_maybe_fp8.shape
         input_maybe_fp8_reshaped = input_maybe_fp8.reshape(-1, orig_shape[-1])
         res_bits = torch.mm(input_maybe_fp8_reshaped, weight_maybe_fp8_t)
+
+        if config._enable_debug_logging:
+            input_hp_reshaped = input_hp.reshape(-1, orig_shape[-1])
+            ref_result = torch.mm(input_hp_reshaped, weight_hp_t)
+            output_mape = mean_absolute_percentage_error(ref_result, res_bits)
+            log_mape(debug_fqn, "output", output_mape)
+
         res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
         return res_bits
 
@@ -94,6 +112,7 @@ def forward(
     def backward(ctx, grad_output):
         input_hp, weight_hp_t = ctx.saved_tensors
         c = ctx.config
+        debug_fqn = ctx.debug_fqn
 
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -144,6 +163,10 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim0,
             weight_t_maybe_fp8_dim0.t(),
         )
+        if c._enable_debug_logging:
+            ref_grad_input = torch.mm(grad_output_reshaped, weight_hp_t.t())
+            grad_input_mape = mean_absolute_percentage_error(ref_grad_input, grad_input)
+            log_mape(debug_fqn, "grad_input", grad_input_mape)
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
@@ -198,8 +221,19 @@ def backward(ctx, grad_output):
             grad_output_reshaped_maybe_fp8_dim1.t(),
             input_reshaped_maybe_fp8_dim1,
         )
+        if c._enable_debug_logging:
+            # don't log if this gemm is in high precision
+            this_gemm_is_hp = (
+                c.cast_config_input_for_grad_weight.scaling_type is ScalingType.DISABLED
+            )
+            if not this_gemm_is_hp:
+                ref_grad_weight = torch.mm(grad_output_reshaped.t(), input_hp_reshaped)
+                grad_weight_mape = mean_absolute_percentage_error(
+                    ref_grad_weight, grad_weight
+                )
+                log_mape(debug_fqn, "grad_weight", grad_weight_mape)
 
-        empty_grads = None, None
+        empty_grads = None, None, None
 
         return grad_input, grad_weight.t(), *empty_grads
 
@@ -252,6 +286,10 @@ def __init__(self, *args, **kwargs):
             ),
         )
 
+        # Debugging only; the API may change at any time. This is expected to
+        # be set by the user in a separate API call.
+        self._debug_fqn: Optional[str] = None
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
@@ -266,6 +304,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight.t(),
             self.linear_mm_config,
             self.config,
+            self._debug_fqn,
         )
 
         if self.bias is not None:
diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py
index 0d9674e6c3..758d1646d5 100644
--- a/torchao/float8/float8_linear_utils.py
+++ b/torchao/float8/float8_linear_utils.py
@@ -196,3 +196,9 @@ def _auto_filter_for_tensorwise(
     if K <= 4096 and N <= 1024:
         return False
     return True
+
+
+def _populate_debug_fqns(model: nn.Module):
+    for name, mod in model.named_modules():
+        if isinstance(mod, Float8Linear):
+            mod._debug_fqn = name
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 625fb29235..0fa70e407f 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -239,3 +239,12 @@ def pad_tensor_for_matmul(
 def _round_scale_down_to_power_of_2(scale: torch.Tensor):
     assert scale.dtype == torch.float32, "scale must be float32 tensor"
     return torch.exp2(torch.floor(torch.log2(scale)))
+
+
+@torch.no_grad()
+def mean_absolute_percentage_error(x_ref, x):
+    tmp = torch.abs(x_ref - x) / torch.clamp(torch.abs(x_ref), min=1e-9)
+    # clamp the largest ratios (caused by reference values close to 0)
+    # to avoid them dominating the mean
+    tmp = torch.clamp(tmp, max=1e3)
+    return torch.mean(tmp)
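For reviewers, a minimal end-to-end sketch of how the new flag is expected to be used, mirroring test_debug_logging above. The module shapes, recipe choice, and sample log line are illustrative assumptions rather than part of this change; `_enable_debug_logging` and `_populate_debug_fqns` are prototype APIs introduced by this diff and may change.

# Minimal usage sketch of the debug logging added in this diff (prototype API).
# Assumes a CUDA device with float8 gemm support; sizes and recipe are
# arbitrary illustration choices.
import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear_utils import (
    _populate_debug_fqns,
    convert_to_float8_training,
)

m = nn.Sequential(
    nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16),
    nn.Linear(512, 256, bias=False, device="cuda", dtype=torch.bfloat16),
)

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
# the config dataclass is frozen, so the test above sets the private flag
# with object.__setattr__
object.__setattr__(config, "_enable_debug_logging", True)

m = convert_to_float8_training(m, config=config)
_populate_debug_fqns(m)  # attach module FQNs so each log line is attributable
m = torch.compile(m)

x = torch.randn(16, 256, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
# prints one line per instrumented gemm via log_mape, e.g.
# fqn: 0, gemm_name: output, mape: tensor(..., device='cuda:0', ...)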
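For intuition on the logged metric, a small standalone check of the mean_absolute_percentage_error helper from the float8_utils.py hunk above; the helper is restated here so the snippet runs on its own, and the input values are arbitrary.

import torch

# restated from the diff for a self-contained check
@torch.no_grad()
def mean_absolute_percentage_error(x_ref, x):
    tmp = torch.abs(x_ref - x) / torch.clamp(torch.abs(x_ref), min=1e-9)
    # clamp huge ratios caused by near-zero reference values
    tmp = torch.clamp(tmp, max=1e3)
    return torch.mean(tmp)

x_ref = torch.tensor([1.0, 2.0, 0.0])
x = torch.tensor([1.1, 1.8, 0.5])
# elementwise ratios: 0.1, 0.1, and 0.5 / 1e-9 clamped to 1e3,
# so the mean is (0.1 + 0.1 + 1000) / 3
print(mean_absolute_percentage_error(x_ref, x))  # ~333.4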