From 26a5cd2b65a4647994fd1a56452f57cbffc37582 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 10 Oct 2025 16:35:03 -0700 Subject: [PATCH 1/5] temp --- .../workflows/float8/float8_tensor.py | 74 +++++++++++++++++++ torchao/utils.py | 9 ++- 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 47395a15af..f72d561e7f 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -256,6 +256,21 @@ def _(func, types, args, kwargs): args[1], args[2] if len(args) > 2 else None, ) + return _float8_linear_impl(input_tensor, weight_tensor, bias) + + +@implements([torch.matmul, aten.mm.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor = args[0], args[1] + print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape} (before transpose)") + return _float8_linear_impl(input_tensor, weight_tensor.t()) + + +def _float8_linear_impl( + input_tensor: torch.Tensor, + weight_tensor: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: assert isinstance(weight_tensor, Float8Tensor), ( f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" ) @@ -665,6 +680,65 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, new) +@implements(torch.ops.aten.to.dtype_layout) +def _(func, types, args, kwargs): + # only support kwargs for now + assert len(args) == 1 + self = args[0] + # only support dtype, layout, and device for now + for k in kwargs.keys(): + assert k in ["dtype", "layout", "device"] + # only support same dtype and layout + # different dtype and layout has undefined behavior + if "dtype" in kwargs: + assert kwargs["dtype"] == self.dtype + if "layout" in kwargs: + assert kwargs["layout"] == self.layout + # if device is the same, treat this like a no-op + device = kwargs.get("device") + if device == self.device: + return self + # otherwise, move all inner tensors to the new device + new_tensor = self.__class__( + func(self.qdata, device=device), + func(self.scale, device=device), + self.block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype + ) + return return_and_correct_aliasing(func, args, kwargs, new_tensor) + + +# This is called during _apply() to see if we can shallow +# copy the content of one tensor into another. For now, +# we only allow shallow copy if both tensors are `Float8Tensor` +@implements(torch._has_compatible_shallow_copy_type) +def _(func, types, args, kwargs): + assert len(args) == 2 + return ( + isinstance(args[0], Float8Tensor) and + isinstance(args[1], Float8Tensor) + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + assert len(args) == 1 + self = args[0] + new_tensor = self.__class__( + self.qdata.t(), + self.scale.t(), + self.block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + self.dtype + ) + return return_and_correct_aliasing(func, args, kwargs, new_tensor) + + Float8Tensor.__module__ = "torchao.quantization" # Allow a model with Float8Tensor weights to be loaded with `weights_only=True` diff --git a/torchao/utils.py b/torchao/utils.py index 5af3e00cfa..9fb1e33f61 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -653,6 +653,7 @@ class MyTensor(torch.Tensor): ... 
__torch_function__ = classmethod(_dispatch__torch_function__) """ + #print(f"dispatch__torch_function__ {func}, cls = {cls}") kwargs = {} if kwargs is None else kwargs if ( hasattr(cls, "_TORCH_FN_TABLE") @@ -661,8 +662,11 @@ class MyTensor(torch.Tensor): ): return cls._TORCH_FN_TABLE[cls][func](func, types, args, kwargs) with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - + try: + return func(*args, **kwargs) + except Exception as e: + print("func is ", func if func is not None else "n/a", "cls is ", cls if cls is not None else "n/a", "args are", args, "kwargs are ", kwargs) + raise e def _dispatch__torch_dispatch__(cls, func, types, args, kwargs): """Use this util function for a common `__torch_dispatch__` implementation @@ -672,6 +676,7 @@ class MyTensor(torch.Tensor): ... __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) """ + #print(f"dispatched to {func}, cls is {cls}, types is {types}, args is {args}, kwargs is {kwargs}") if ( hasattr(cls, "_ATEN_OP_TABLE") and cls in cls._ATEN_OP_TABLE From d63cbe97058c895fef469ce776cfb75c9b82f251 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Sat, 11 Oct 2025 19:22:09 -0700 Subject: [PATCH 2/5] runs without errors --- .../workflows/float8/float8_tensor.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index f72d561e7f..9aa187b610 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -262,10 +262,32 @@ def _(func, types, args, kwargs): @implements([torch.matmul, aten.mm.default]) def _(func, types, args, kwargs): input_tensor, weight_tensor = args[0], args[1] - print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape} (before transpose)") + print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose)") return _float8_linear_impl(input_tensor, weight_tensor.t()) +@implements([aten.addmm_.default]) +def _(func, types, args, kwargs): + output_tensor, input_tensor, weight_tensor = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose), output_tensor = {output_tensor.shape}") + out = _float8_linear_impl(input_tensor, weight_tensor.t()) + return output_tensor.copy_(out) + + +@implements(aten.copy_.default) +def _(func, types, args, kwargs): + # For now, just support copying from a Float8Tensor to a Float8Tensor + assert len(args) == 2 + assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) + args[0].qdata.copy_(args[1].qdata, **kwargs) + args[0].scale.copy_(args[1].scale, **kwargs) + return args[0] + + def _float8_linear_impl( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, @@ -310,10 +332,12 @@ def _float8_linear_impl( wq = weight_tensor.qdata x_scale = input_tensor.scale w_scale = weight_tensor.scale - if _is_rowwise_scaled(weight_tensor): + if True: #_is_rowwise_scaled(weight_tensor): assert _is_rowwise_scaled(input_tensor), ( "Input tensor must be rowwise block size" ) + print(f" * fbgemm op input = {xq.shape}, weight = {wq.shape}, input_scale = {x_scale.shape}, weight_scale = {w_scale.shape}") + wq = wq.contiguous() res = torch.ops.fbgemm.f8f8bf16_rowwise( xq, wq, @@ -323,6 +347,8 
@@ def _float8_linear_impl( use_fast_accum=mm_config.use_fast_accum, ).reshape(out_shape) else: + print("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!") + breakpoint() assert _is_tensorwise_scaled(weight_tensor) assert _is_tensorwise_scaled(input_tensor) res = torch.ops.fbgemm.f8f8bf16( @@ -727,10 +753,11 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): assert len(args) == 1 self = args[0] + assert len(self.block_size) == 2 new_tensor = self.__class__( self.qdata.t(), self.scale.t(), - self.block_size, + (self.block_size[1], self.block_size[0]), self.mm_config, self.act_quant_kwargs, self.kernel_preference, From 86e14a187e915b46e34a0ad47e26daddac8a5434 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Sun, 12 Oct 2025 15:06:00 -0700 Subject: [PATCH 3/5] Fix Float8Tensor view bug --- .../workflows/float8/float8_tensor.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 9aa187b610..deba4a6c17 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -598,8 +598,22 @@ def _(func, types, args, kwargs): assert original_shape[-1] == size[-1], ( f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}" ) - qdata = self.qdata.reshape(*size) - scale = self.scale.reshape(*size) + # TODO(andrew): This is technically not needed for unsloth fp8 RL + # but fixes a bug nonetheless, can do this separately + # Example input shapes: + # self.shape = [6, 363, 4096] + # self.scale.shape = [6, 363, 1] + # self.block_size = [1, 1, 4096] + # size = [-1, 4096] + # + # Example output shapes: + # self.shape = [2178, 4096] + # self.scale.shape = [2178, 1] + # self.block_size = [1, 4096] + new_dim0 = original_shape[0] * original_shape[1] + assert size[0] == new_dim0 or size[0] == -1 + qdata = self.qdata.reshape(new_dim0, -1) + scale = self.scale.reshape(new_dim0, -1) block_size = self.block_size.copy() block_size = [block_size[0] * block_size[1], block_size[2]] elif len(original_shape) == 2 and len(size) == 3: From 19500bf07bcaeca016817fd1f23545ca8a9a7658 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Sun, 12 Oct 2025 15:07:56 -0700 Subject: [PATCH 4/5] clean up a bit --- .../workflows/float8/float8_tensor.py | 38 ++++++------------- torchao/utils.py | 9 +---- 2 files changed, 13 insertions(+), 34 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index deba4a6c17..fbaae16061 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -248,8 +248,8 @@ def from_hp( implements_torch_function = Float8Tensor.implements_torch_function -@implements([aten.linear.default]) -@implements_torch_function([torch.nn.functional.linear]) +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( args[0], @@ -259,35 +259,24 @@ def _(func, types, args, kwargs): return _float8_linear_impl(input_tensor, weight_tensor, bias) -@implements([torch.matmul, aten.mm.default]) +@implements(aten.mm.default) +@implements_torch_function(torch.matmul) def _(func, types, args, kwargs): input_tensor, 
weight_tensor = args[0], args[1] - print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose)") return _float8_linear_impl(input_tensor, weight_tensor.t()) -@implements([aten.addmm_.default]) +@implements(aten.addmm_.default) def _(func, types, args, kwargs): output_tensor, input_tensor, weight_tensor = ( args[0], args[1], args[2] if len(args) > 2 else None, ) - print(f"input = {input_tensor.shape}, weight = {weight_tensor.shape}, weight.block_size = {weight_tensor.block_size} (before transpose), output_tensor = {output_tensor.shape}") out = _float8_linear_impl(input_tensor, weight_tensor.t()) return output_tensor.copy_(out) -@implements(aten.copy_.default) -def _(func, types, args, kwargs): - # For now, just support copying from a Float8Tensor to a Float8Tensor - assert len(args) == 2 - assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) - args[0].qdata.copy_(args[1].qdata, **kwargs) - args[0].scale.copy_(args[1].scale, **kwargs) - return args[0] - - def _float8_linear_impl( input_tensor: torch.Tensor, weight_tensor: torch.Tensor, @@ -332,11 +321,11 @@ def _float8_linear_impl( wq = weight_tensor.qdata x_scale = input_tensor.scale w_scale = weight_tensor.scale - if True: #_is_rowwise_scaled(weight_tensor): + # TODO: fix this? + if True: # _is_rowwise_scaled(weight_tensor): assert _is_rowwise_scaled(input_tensor), ( "Input tensor must be rowwise block size" ) - print(f" * fbgemm op input = {xq.shape}, weight = {wq.shape}, input_scale = {x_scale.shape}, weight_scale = {w_scale.shape}") wq = wq.contiguous() res = torch.ops.fbgemm.f8f8bf16_rowwise( xq, @@ -347,8 +336,6 @@ def _float8_linear_impl( use_fast_accum=mm_config.use_fast_accum, ).reshape(out_shape) else: - print("weight_tensor failed _is_rowwise_scaled, SHOULDN'T BE HERE!!!!!!") - breakpoint() assert _is_tensorwise_scaled(weight_tensor) assert _is_tensorwise_scaled(input_tensor) res = torch.ops.fbgemm.f8f8bf16( @@ -746,7 +733,7 @@ def _(func, types, args, kwargs): self.mm_config, self.act_quant_kwargs, self.kernel_preference, - self.dtype + self.dtype, ) return return_and_correct_aliasing(func, args, kwargs, new_tensor) @@ -754,13 +741,10 @@ def _(func, types, args, kwargs): # This is called during _apply() to see if we can shallow # copy the content of one tensor into another. For now, # we only allow shallow copy if both tensors are `Float8Tensor` -@implements(torch._has_compatible_shallow_copy_type) +@implements_torch_function(torch._has_compatible_shallow_copy_type) def _(func, types, args, kwargs): assert len(args) == 2 - return ( - isinstance(args[0], Float8Tensor) and - isinstance(args[1], Float8Tensor) - ) + return isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) @implements(aten.t.default) @@ -775,7 +759,7 @@ def _(func, types, args, kwargs): self.mm_config, self.act_quant_kwargs, self.kernel_preference, - self.dtype + self.dtype, ) return return_and_correct_aliasing(func, args, kwargs, new_tensor) diff --git a/torchao/utils.py b/torchao/utils.py index 9fb1e33f61..5af3e00cfa 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -653,7 +653,6 @@ class MyTensor(torch.Tensor): ... 
__torch_function__ = classmethod(_dispatch__torch_function__) """ - #print(f"dispatch__torch_function__ {func}, cls = {cls}") kwargs = {} if kwargs is None else kwargs if ( hasattr(cls, "_TORCH_FN_TABLE") @@ -662,11 +661,8 @@ class MyTensor(torch.Tensor): ): return cls._TORCH_FN_TABLE[cls][func](func, types, args, kwargs) with torch._C.DisableTorchFunctionSubclass(): - try: - return func(*args, **kwargs) - except Exception as e: - print("func is ", func if func is not None else "n/a", "cls is ", cls if cls is not None else "n/a", "args are", args, "kwargs are ", kwargs) - raise e + return func(*args, **kwargs) + def _dispatch__torch_dispatch__(cls, func, types, args, kwargs): """Use this util function for a common `__torch_dispatch__` implementation @@ -676,7 +672,6 @@ class MyTensor(torch.Tensor): ... __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) """ - #print(f"dispatched to {func}, cls is {cls}, types is {types}, args is {args}, kwargs is {kwargs}") if ( hasattr(cls, "_ATEN_OP_TABLE") and cls in cls._ATEN_OP_TABLE From 3d4cb8dbe12cb440085799816a030f010d185db3 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 14 Oct 2025 17:07:07 -0700 Subject: [PATCH 5/5] Dequantize fp8 rowwise in backward --- .../quantize_/workflows/float8/float8_tensor.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index fbaae16061..6ff3df419f 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -286,6 +286,20 @@ def _float8_linear_impl( f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}" ) + # TODO: make this better + # During the backward pass, we transpose the weight tensor, + # so if the weight tensor was originally rowwise quantized, + # now it becomes colwise. In this case, simply dequantize + # the tensor and do a bf16 matmul + is_backward = ( + weight_tensor.block_size[0] == weight_tensor.shape[0] and + weight_tensor.block_size[1] == 1 + ) + if is_backward: + return torch.nn.functional.linear( + input_tensor, weight_tensor.dequantize(), bias, + ) + act_quant_kwargs = weight_tensor.act_quant_kwargs # quantizing activation, if `act_quant_kwargs` is specified if act_quant_kwargs is not None: @@ -321,8 +335,7 @@ def _float8_linear_impl( wq = weight_tensor.qdata x_scale = input_tensor.scale w_scale = weight_tensor.scale - # TODO: fix this? - if True: # _is_rowwise_scaled(weight_tensor): + if _is_rowwise_scaled(weight_tensor): assert _is_rowwise_scaled(input_tensor), ( "Input tensor must be rowwise block size" )
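
Note on the overall approach in this series: `block_size` records how `scale` maps onto `qdata`, so every op that reshapes or transposes the quantized tensor (the new `aten.t.default` and `aten.view` overrides) must transform `block_size` the same way, and the linear override can then use it to tell a rowwise-scaled weight (forward path, fbgemm rowwise kernel) from a colwise-scaled one (backward path after `aten.t`, which the last patch handles by dequantizing and running a bf16 matmul). The sketch below is a minimal standalone illustration of that bookkeeping; the class and helper names, the simulated quantization, and the fp32 storage are assumptions for readability and are not the torchao `Float8Tensor` API.

# Minimal sketch (assumed names, simulated quantization -- not the torchao API):
# it only shows how block_size bookkeeping distinguishes rowwise from colwise
# scaling across a transpose, as the patch series does for Float8Tensor.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class ToyScaledTensor:
    qdata: torch.Tensor           # "quantized" payload (kept in fp32 here for simplicity)
    scale: torch.Tensor           # one scale per block
    block_size: Tuple[int, int]   # elements covered by each scale, per dim

    @property
    def shape(self):
        return self.qdata.shape

    def t(self) -> "ToyScaledTensor":
        # Mirror the aten.t.default override: transpose payload and scale,
        # and swap block_size so it still describes the new layout.
        return ToyScaledTensor(
            self.qdata.t(), self.scale.t(),
            (self.block_size[1], self.block_size[0]),
        )

    def dequantize(self) -> torch.Tensor:
        return self.qdata * self.scale


def quantize_rowwise(w: torch.Tensor) -> ToyScaledTensor:
    # One scale per row -> block_size == (1, w.shape[1]).
    scale = w.abs().amax(dim=1, keepdim=True) / 448.0  # 448 is the float8_e4m3fn max
    qdata = w / scale  # a real implementation would cast to torch.float8_e4m3fn here
    return ToyScaledTensor(qdata, scale, (1, w.shape[1]))


def is_rowwise(t: ToyScaledTensor) -> bool:
    return t.block_size == (1, t.shape[1])


def is_colwise(t: ToyScaledTensor) -> bool:
    # The "is_backward" case in the last patch: block_size == (shape[0], 1)
    # after the weight has been transposed.
    return t.block_size == (t.shape[0], 1)


def linear(x: torch.Tensor, w: ToyScaledTensor,
           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    if is_colwise(w):
        # Transposed (backward-style) weight: fall back to a dequantized bf16/fp32
        # matmul instead of calling the rowwise fp8 kernel, as patch 5 does.
        return torch.nn.functional.linear(x, w.dequantize(), bias)
    assert is_rowwise(w), "weight must be rowwise scaled on the forward path"
    # Stand-in for torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, ...)
    return torch.nn.functional.linear(x, w.dequantize(), bias)


w = quantize_rowwise(torch.randn(16, 32))
x = torch.randn(4, 32)
print(linear(x, w).shape)                        # forward path, rowwise weight
print(linear(torch.randn(4, 16), w.t()).shape)   # transposed weight takes the dequantize path

The usage lines at the end exercise both branches: the untransposed weight keeps block_size (1, K) and stays on the rowwise path, while w.t() carries block_size (K, 1) and is routed through dequantize, which is the same decision the `is_backward` check makes in `_float8_linear_impl`.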