diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 47395a15af..6ff3df419f 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -248,18 +248,58 @@ def from_hp(
 implements_torch_function = Float8Tensor.implements_torch_function
 
 
-@implements([aten.linear.default])
-@implements_torch_function([torch.nn.functional.linear])
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
         args[1],
         args[2] if len(args) > 2 else None,
     )
+    return _float8_linear_impl(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.mm.default)
+@implements_torch_function(torch.matmul)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = args[0], args[1]
+    return _float8_linear_impl(input_tensor, weight_tensor.t())
+
+
+@implements(aten.addmm_.default)
+def _(func, types, args, kwargs):
+    output_tensor, input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    out = _float8_linear_impl(input_tensor, weight_tensor.t())
+    return output_tensor.copy_(out)
+
+
+def _float8_linear_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     assert isinstance(weight_tensor, Float8Tensor), (
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )
+    # TODO: make this better
+    # During the backward pass, we transpose the weight tensor,
+    # so if the weight tensor was originally rowwise quantized,
+    # now it becomes colwise. In this case, simply dequantize
+    # the tensor and do a bf16 matmul
+    is_backward = (
+        weight_tensor.block_size[0] == weight_tensor.shape[0] and
+        weight_tensor.block_size[1] == 1
+    )
+    if is_backward:
+        return torch.nn.functional.linear(
+            input_tensor, weight_tensor.dequantize(), bias,
+        )
+
     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantizing activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
@@ -299,6 +339,7 @@ def _(func, types, args, kwargs):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
+        wq = wq.contiguous()
         res = torch.ops.fbgemm.f8f8bf16_rowwise(
             xq,
             wq,
@@ -557,8 +598,22 @@ def _(func, types, args, kwargs):
         assert original_shape[-1] == size[-1], (
             f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
         )
-        qdata = self.qdata.reshape(*size)
-        scale = self.scale.reshape(*size)
+        # TODO(andrew): This is technically not needed for unsloth fp8 RL
+        # but fixes a bug nonetheless, can do this separately
+        # Example input shapes:
+        #   self.shape = [6, 363, 4096]
+        #   self.scale.shape = [6, 363, 1]
+        #   self.block_size = [1, 1, 4096]
+        #   size = [-1, 4096]
+        #
+        # Example output shapes:
+        #   self.shape = [2178, 4096]
+        #   self.scale.shape = [2178, 1]
+        #   self.block_size = [1, 4096]
+        new_dim0 = original_shape[0] * original_shape[1]
+        assert size[0] == new_dim0 or size[0] == -1
+        qdata = self.qdata.reshape(new_dim0, -1)
+        scale = self.scale.reshape(new_dim0, -1)
         block_size = self.block_size.copy()
         block_size = [block_size[0] * block_size[1], block_size[2]]
     elif len(original_shape) == 2 and len(size) == 3:
@@ -665,6 +720,63 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 
+@implements(torch.ops.aten.to.dtype_layout)
+def _(func, types, args, kwargs):
+    # only support kwargs for now
+    assert len(args) == 1
+    self = args[0]
+    # only support dtype, layout, and device for now
+    for k in kwargs.keys():
+        assert k in ["dtype", "layout", "device"]
+    # only support same dtype and layout
+    # different dtype and layout has undefined behavior
+    if "dtype" in kwargs:
+        assert kwargs["dtype"] == self.dtype
+    if "layout" in kwargs:
+        assert kwargs["layout"] == self.layout
+    # if device is the same, treat this like a no-op
+    device = kwargs.get("device")
+    if device == self.device:
+        return self
+    # otherwise, move all inner tensors to the new device
+    new_tensor = self.__class__(
+        func(self.qdata, device=device),
+        func(self.scale, device=device),
+        self.block_size,
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_tensor)
+
+
+# This is called during _apply() to see if we can shallow
+# copy the content of one tensor into another. For now,
+# we only allow shallow copy if both tensors are `Float8Tensor`
+@implements_torch_function(torch._has_compatible_shallow_copy_type)
+def _(func, types, args, kwargs):
+    assert len(args) == 2
+    return isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
+
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    assert len(args) == 1
+    self = args[0]
+    assert len(self.block_size) == 2
+    new_tensor = self.__class__(
+        self.qdata.t(),
+        self.scale.t(),
+        (self.block_size[1], self.block_size[0]),
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_tensor)
+
+
 Float8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
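The `is_backward` heuristic in `_float8_linear_impl` works because the `aten.t.default` override at the bottom of this diff swaps `block_size` together with `qdata` and `scale`: a rowwise-quantized weight of shape `[out_features, in_features]` carries `block_size == [1, in_features]`, which a transpose turns into `[in_features, 1]`. A minimal sketch of that shape arithmetic; the helper name is illustrative, not part of the patch:

```python
def looks_like_transposed_rowwise(shape, block_size):
    # Mirrors the `is_backward` condition in `_float8_linear_impl`:
    # one scale per column means the weight was rowwise-quantized
    # and has since been transposed (e.g. for the backward pass).
    return block_size[0] == shape[0] and block_size[1] == 1

# Forward: weight [4096, 1024], rowwise quantized -> block_size [1, 1024]
assert not looks_like_transposed_rowwise([4096, 1024], [1, 1024])

# Backward: weight.t() has shape [1024, 4096] with swapped block_size
# [1024, 1], so the patch falls back to dequantize + bf16 linear.
assert looks_like_transposed_rowwise([1024, 4096], [1024, 1])
```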
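The reshape change in the third hunk also fixes a real bug: the removed code reshaped `scale` with the same `size` used for `qdata`, and for the shapes in the comment that means reshaping a `[6, 363, 1]` scale to `[-1, 4096]`, which cannot work (2178 elements are not divisible by 4096). A quick check of the new shape arithmetic, using plain tensors as stand-ins since dtypes are irrelevant to the point:

```python
import torch

qdata = torch.empty(6, 363, 4096)  # stand-in for self.qdata
scale = torch.empty(6, 363, 1)     # stand-in for self.scale
block_size = [1, 1, 4096]

# Collapse the two leading dims explicitly, as the new code does.
new_dim0 = qdata.shape[0] * qdata.shape[1]  # 6 * 363 = 2178
assert qdata.reshape(new_dim0, -1).shape == (2178, 4096)
assert scale.reshape(new_dim0, -1).shape == (2178, 1)
assert [block_size[0] * block_size[1], block_size[2]] == [1, 4096]
```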
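The `aten.to.dtype_layout` and `torch._has_compatible_shallow_copy_type` overrides are what let `nn.Module.to()`, which walks parameters via `_apply()`, move `Float8Tensor` weights between devices, and `aten.t.default` makes the transposed weight usable by the new `mm`/`matmul` path. A hedged end-to-end sketch: `quantize_`, `Float8DynamicActivationFloat8WeightConfig`, and `PerRow` are existing torchao entry points, but the exact config used with this workflow is an assumption here:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = torch.nn.Linear(1024, 4096, dtype=torch.bfloat16)
# Assumed config; rowwise weight quantization matches the block_size
# conventions discussed above.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

model.to("cuda")  # _apply() shallow-copies via the new overrides

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y1 = model(x)                            # linear override
y2 = torch.matmul(x, model.weight.t())   # new matmul/mm + t() overrides
```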