120 changes: 116 additions & 4 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -248,18 +248,58 @@ def from_hp(
implements_torch_function = Float8Tensor.implements_torch_function


-@implements([aten.linear.default])
-@implements_torch_function([torch.nn.functional.linear])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
return _float8_linear_impl(input_tensor, weight_tensor, bias)
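
# Illustrative usage (names assumed, not part of this PR): once a module's
# weight has been swapped for a Float8Tensor, a plain functional linear call
# dispatches to the override above:
#
#     y = torch.nn.functional.linear(x, float8_weight, bias)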


@implements(aten.mm.default)
@implements_torch_function(torch.matmul)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = args[0], args[1]
return _float8_linear_impl(input_tensor, weight_tensor.t())
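
# Illustrative note: `torch.matmul(x, w)` contracts over the rows of `w`, while
# the linear path expects a [out_features, in_features] weight, so the override
# above reuses `_float8_linear_impl` with `w.t()`. Shapes here are assumed:
#
#     x = torch.randn(16, 4096, dtype=torch.bfloat16)
#     y = x @ w_fp8  # w_fp8: Float8Tensor [4096, 1024] -> y: [16, 1024]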


@implements(aten.addmm_.default)
def _(func, types, args, kwargs):
output_tensor, input_tensor, weight_tensor = (
args[0],
args[1],
args[2],  # mat2 is always required for addmm_
)
out = _float8_linear_impl(input_tensor, weight_tensor.t())
return output_tensor.copy_(out)
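
# Illustrative note: `aten.addmm_` is the in-place variant `out.addmm_(x, w)`;
# the override recomputes `x @ w` through the float8 linear path and copies the
# result into `out`. Note that the `beta`/`alpha` scaling arguments of `addmm_`
# are not handled by this override.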


def _float8_linear_impl(
input_tensor: torch.Tensor,
weight_tensor: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert isinstance(weight_tensor, Float8Tensor), (
    f"Currently only a Float8Tensor weight is expected to reach this override, "
    f"got: {type(input_tensor)} {type(weight_tensor)}"
)

# TODO: make this better
# During the backward pass, the weight tensor is transposed, so a weight
# that was originally rowwise quantized becomes colwise. In this case,
# simply dequantize the tensor and fall back to a bf16 matmul.
is_backward = (
weight_tensor.block_size[0] == weight_tensor.shape[0] and
weight_tensor.block_size[1] == 1
)
if is_backward:
return torch.nn.functional.linear(
input_tensor, weight_tensor.dequantize(), bias,
)
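
# Illustration (shapes assumed): a rowwise-quantized [1024, 4096] weight has
# block_size == [1, 4096] and scale.shape == [1024, 1]. After `aten.t` in the
# backward pass, shape == [4096, 1024] and block_size == [4096, 1], so
# block_size[0] == shape[0] and block_size[1] == 1, which is exactly what the
# check above detects.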

act_quant_kwargs = weight_tensor.act_quant_kwargs
# quantizing activation, if `act_quant_kwargs` is specified
if act_quant_kwargs is not None:
@@ -299,6 +339,7 @@ def _(func, types, args, kwargs):
assert _is_rowwise_scaled(input_tensor), (
"Input tensor must be rowwise block size"
)
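# ensure wq is row-major contiguous before the fbgemm call; the rowwise kernel
# appears to require contiguous inputs (assumption, based on this fix)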
wq = wq.contiguous()
res = torch.ops.fbgemm.f8f8bf16_rowwise(
xq,
wq,
Expand Down Expand Up @@ -557,8 +598,22 @@ def _(func, types, args, kwargs):
assert original_shape[-1] == size[-1], (
f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
)
-qdata = self.qdata.reshape(*size)
-scale = self.scale.reshape(*size)
# TODO(andrew): This is technically not needed for unsloth fp8 RL,
# but it fixes a bug nonetheless; it can be done separately.
# Example input shapes:
# self.shape = [6, 363, 4096]
# self.scale.shape = [6, 363, 1]
# self.block_size = [1, 1, 4096]
# size = [-1, 4096]
#
# Example output shapes:
# self.shape = [2178, 4096]
# self.scale.shape = [2178, 1]
# self.block_size = [1, 4096]
new_dim0 = original_shape[0] * original_shape[1]
assert size[0] == new_dim0 or size[0] == -1
qdata = self.qdata.reshape(new_dim0, -1)
scale = self.scale.reshape(new_dim0, -1)
block_size = self.block_size.copy()
block_size = [block_size[0] * block_size[1], block_size[2]]
elif len(original_shape) == 2 and len(size) == 3:
Expand Down Expand Up @@ -665,6 +720,63 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, new)


@implements(torch.ops.aten.to.dtype_layout)
def _(func, types, args, kwargs):
# only support kwargs for now
assert len(args) == 1
self = args[0]
# only support dtype, layout, and device for now
for k in kwargs:
assert k in ["dtype", "layout", "device"]
# only support the same dtype and layout;
# converting to a different dtype or layout here is undefined behavior
if "dtype" in kwargs:
assert kwargs["dtype"] == self.dtype
if "layout" in kwargs:
assert kwargs["layout"] == self.layout
# if device is the same, treat this like a no-op
device = kwargs.get("device")
if device == self.device:
return self
# otherwise, move all inner tensors to the new device
new_tensor = self.__class__(
func(self.qdata, device=device),
func(self.scale, device=device),
self.block_size,
self.mm_config,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
)
return return_and_correct_aliasing(func, args, kwargs, new_tensor)
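
# Illustrative usage (assumption): moving a quantized model between devices
# routes through the override above, e.g.:
#
#     model.to(device="cuda")  # moves qdata and scale; dtype/layout unchanged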


# This is called during _apply() to see if we can shallow
# copy the content of one tensor into another. For now,
# we only allow shallow copy if both tensors are `Float8Tensor`
@implements_torch_function(torch._has_compatible_shallow_copy_type)
def _(func, types, args, kwargs):
assert len(args) == 2
return isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
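
# Illustrative note: `nn.Module._apply` (triggered by `model.to(...)`,
# `model.cuda()`, etc.) consults this check to decide whether a parameter can
# be updated in place instead of being replaced with a new tensor object.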


@implements(aten.t.default)
def _(func, types, args, kwargs):
assert len(args) == 1
self = args[0]
assert len(self.block_size) == 2
new_tensor = self.__class__(
self.qdata.t(),
self.scale.t(),
(self.block_size[1], self.block_size[0]),
self.mm_config,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
)
return return_and_correct_aliasing(func, args, kwargs, new_tensor)
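
# Illustrative note: transposing swaps qdata/scale and reverses block_size, so
# a rowwise-quantized weight (block_size == [1, K]) becomes colwise
# (block_size == [K, 1]) after `.t()`, the orientation that
# `_float8_linear_impl` treats as the backward pass above.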


Float8Tensor.__module__ = "torchao.quantization"

# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`