diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index c19478e02a..cec466ea1d 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -370,8 +370,9 @@ def test_linear_from_config_params(
 @pytest.mark.parametrize(
     "recipe_name",
     [
-        Float8LinearRecipeName.ROWWISE,
-        Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        # Float8LinearRecipeName.ROWWISE,
+        # Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
+        Float8LinearRecipeName.FWD_ROWWISE_GI_ROWWISE_GW_HP,
     ],
 )
 @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index b362390946..81f99b9e5f 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -134,6 +134,12 @@ class Float8LinearRecipeName(enum.Enum):
     # * the e4m3 dtype is used across the board, including for gradients
     ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
 
+    # debug only, not for land
+    FWD_FLOAT8_BWD_HP = "fwd_float8_bwd_hp"
+
+    # debug only, not for land
+    FWD_ROWWISE_GI_ROWWISE_GW_HP = "fwd_rowwise_gi_rowwise_gw_hp"
+
 
 @dataclass(frozen=True)
 class Float8LinearConfig:
@@ -336,5 +342,59 @@ def from_recipe_name(
             round_scales_to_power_of_2=True,
         )
 
+    elif recipe_name is Float8LinearRecipeName.FWD_FLOAT8_BWD_HP:
+        # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
+        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+
+        # grad_input_hp = grad_output_hp @ weight_hp
+        cc_go = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_w_gi = CastConfig(scaling_type=ScalingType.DISABLED)
+
+        # grad_weight_hp = input_t_hp @ grad_output_hp
+        cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(
+            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+        )
+
+        return Float8LinearConfig(
+            cast_config_input=cc_i,
+            cast_config_weight=cc_w,
+            cast_config_grad_output=cc_go,
+            cast_config_input_for_grad_weight=cc_i_gw,
+            cast_config_weight_for_grad_input=cc_w_gi,
+            cast_config_grad_output_for_grad_weight=cc_go_gw,
+            round_scales_to_power_of_2=True,
+        )
+
+    elif recipe_name is Float8LinearRecipeName.FWD_ROWWISE_GI_ROWWISE_GW_HP:
+        # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
+        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+
+        # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_axiswise_dim1
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
+        cc_w_gi = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
+
+        # grad_weight_hp = input_t_hp @ grad_output_hp
+        cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(
+            scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
+        )
+
+        return Float8LinearConfig(
+            cast_config_input=cc_i,
+            cast_config_weight=cc_w,
+            cast_config_grad_output=cc_go,
+            cast_config_input_for_grad_weight=cc_i_gw,
+            cast_config_weight_for_grad_input=cc_w_gi,
+            cast_config_grad_output_for_grad_weight=cc_go_gw,
+            round_scales_to_power_of_2=True,
+        )
+
     else:
         raise AssertionError(f"unknown recipe_name {recipe_name}")
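
For context, a minimal sketch of how the new debug recipe could be exercised outside the test suite, assuming a CUDA device with float8 support and the existing torchao entry points (`Float8LinearConfig.from_recipe_name` and `convert_to_float8_training`); the module shape and input below are illustrative only, and since the recipes are flagged "debug only, not for land" this is meant for local numerics/benchmark comparison rather than production use:

```python
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

# build a config from the new debug recipe: the forward and grad_input
# gemms use rowwise (axiswise) float8 scaling, grad_weight stays in
# high precision
config = Float8LinearConfig.from_recipe_name(
    Float8LinearRecipeName.FWD_ROWWISE_GI_ROWWISE_GW_HP
)

# swap nn.Linear modules for Float8Linear and run a forward/backward pass
m = nn.Sequential(nn.Linear(16, 16, bias=False)).cuda().bfloat16()
m = convert_to_float8_training(m, config=config)

x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()
```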