diff --git a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
index 85c2132731..b15ff63093 100644
--- a/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py
@@ -8,25 +8,21 @@ import torch
 from torch.testing._internal.common_utils import (
-    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-    is_sm_at_least_90,
-)
+from torchao.testing.utils import TorchAOIntegrationTestCase
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
-class TestInt4Tensor(TestCase):
+class TestInt4Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
@@ -61,50 +57,46 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
-        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
+        self.assertEqual(weight1.zero_point, dummy.weight.zero_point.narrow(1, 0, 64))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
+        self.assertEqual(weight2.zero_point, dummy.weight.zero_point.narrow(0, 0, 1))
 
         # check that the sliced weight, before and after int4 quantization,
         # does not differ too much
         input = torch.randn(2, 256, dtype=dtype, device=device)
         res_ref = dummy1(input)
-        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
         res = dummy(input)
         assert compute_error(res, res_ref) > 20
 
         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
-        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
         res = dummy(input)
         assert compute_error(res, res_ref) > 15
 
-    def test_slice_and_copy_(self):
+    def test_slice_preserves_aliasing(self):
+        config = self.config
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert param.data._data.data_ptr() == param_data._data.data_ptr()
+        # Making sure the aliasing is preserved in the sliced quantized Tensor
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data._data[0][0].item()
-
-        # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
-        quantized = dummy_l.weight
-        quantized = quantized.narrow(0, 0, 512)
-        param_data.copy_(quantized)
-
-        # making sure param.data is updated
-        assert param.data._data[0][0] != orig_value
+    def test_slice_and_copy_similar_to_vllm(self):
+        self._test_slice_and_copy_similar_to_vllm(self.config)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -126,20 +118,103 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_to_device(self, sizes):
+        config = self.config
+        M, N, K = sizes
+        dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
 
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device=device)
+            linear(input_tensor)
 
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
+
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_cat(self, sizes):
+        config = self.config
+        dtype = torch.bfloat16
+        device = "cuda"
+        M, N, K = sizes
+        linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+
+        dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy_linear1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (2 * N, K))
+        self.assertEqual(
+            dummy_linear1.weight.qdata,
+            cat_qweight1.qdata,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.scale,
+            cat_qweight1.scale,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.zero_point,
+            cat_qweight1.zero_point,
+        )
+
+        # making sure cat_qweight1 can be used for inference
+        dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
+        dummy_linear1(input_cat1)
+
+        # align the scale and zero_point before concatenating in the grouped dim
+        linear2.weight.scale = linear1.weight.scale
+        linear2.weight.zero_point = linear1.weight.zero_point
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (N, 2 * K))
+        ref_data = torch.cat(
+            [
+                linear1.weight.qdata,
+                linear2.weight.qdata,
+            ],
+            dim=1,
+        )
+        ref_scale = linear1.weight.scale
+        ref_zero_point = linear1.weight.zero_point
+        self.assertEqual(cat_qweight2.qdata, ref_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+        self.assertEqual(cat_qweight2.zero_point, ref_zero_point)
+
+    def test_moe_weight_reshape_ops(self):
+        self._test_moe_weight_reshape_ops(self.config)
+
+
+instantiate_parametrized_tests(TestInt4Tensor)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/test/test_utils.py b/test/test_utils.py
index 3ba2f32613..36d2823495 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -76,7 +76,7 @@ def __init__(self, qdata, attr, device=None):
                 self.qdata = qdata
                 self.attr = attr
 
-        l = torch.nn.Linear(1, 1)
+        l = torch.nn.Linear(2, 3)
         l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
         lp_tensor = l.weight
         # test __tensor_flatten__ and __tensor_unflatten__
@@ -107,18 +107,24 @@ def __init__(self, qdata, attr, device=None):
         # explicitly testing aten.alias
         lp_tensor = torch.ops.aten.alias(lp_tensor)
         lp_tensor = lp_tensor.clone()
+        # making qdata not contiguous
+        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1).contiguous()
+        lp_tensor.qdata = lp_tensor.qdata.transpose(0, 1)
+        self.assertFalse(lp_tensor.qdata.is_contiguous())
         lp_tensor = lp_tensor.contiguous()
+        # making sure contiguous call works
+        self.assertTrue(lp_tensor.qdata.is_contiguous())
 
         # copy_
-        another_tensor = torch.nn.Linear(1, 1).weight
+        another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
         another_lp_tensor = MyTensor(another_tensor, "attr")
         # initially tensor values are not the same
-        self.assertNotEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
+        self.assertNotEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
         lp_tensor.copy_(another_lp_tensor)
         self.assertEqual(lp_tensor.attr, "attr")
         # after copy_, the tensor values should match
-        self.assertEqual(lp_tensor.qdata[0], another_lp_tensor.qdata[0])
+        self.assertEqual(lp_tensor.qdata[0][0], another_lp_tensor.qdata[0][0])
 
 
 if __name__ == "__main__":
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 5d79563ab1..a07297d74a 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1160,6 +1160,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
 
     if config.VERSION == 2:
+        block_size = list(block_size)
         if packing_format == PackingFormat.PRESHUFFLED:
             new_weight = Int4PreshuffledTensor.from_float(
                 weight,
@@ -1168,7 +1169,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
             )
             return new_weight
         elif packing_format == PackingFormat.PLAIN:
-            new_weight = Int4Tensor.from_float(
+            new_weight = Int4Tensor.from_hp(
                 weight,
                 block_size,
             )
@@ -2212,7 +2213,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype=torch.bfloat16,
         )
     else:
-        weight = Int4Tensor.from_float(
+        weight = Int4Tensor.from_hp(
             module.weight,
             config.block_size,
         )
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py
index 371ab6de2b..ebf36dd644 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py
@@ -10,11 +10,7 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-    TorchAOBaseTensor,
-    fill_defaults,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, fill_defaults
 
 __all__ = [
     "Int4Tensor",
@@ -35,10 +31,10 @@ class Int4Tensor(TorchAOBaseTensor):
     int4 quantization with plain (default) packing format (for all granularities)
 
     Tensor Attributes:
-        _data: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
-        scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+        qdata: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
+        scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
             dtype is the same as the original Tensor dtype
-        zero_point: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor, where B is batch size,
+        zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
             dtype is the same as the original Tensor dtype
 
     Non-Tensor Attributes:
@@ -46,64 +42,27 @@ class Int4Tensor(TorchAOBaseTensor):
         shape: the shape of the original Tensor
     """
 
-    tensor_data_attrs = ["_data", "scale", "zero_point"]
-    tensor_attributes = ["block_size", "shape"]
+    tensor_data_names = ["qdata", "scale", "zero_point"]
+    tensor_attribute_names = ["block_size", "shape"]
 
-    def __new__(cls, _data, scale, zero_point, block_size, shape):
+    def __new__(cls, qdata, scale, zero_point, block_size, shape):
         kwargs = {}
-        kwargs["device"] = _data.device
+        kwargs["device"] = qdata.device
         kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, _data, scale, zero_point, block_size, shape):
-        self._data = _data
+    def __init__(self, qdata, scale, zero_point, block_size, shape):
+        self.qdata = qdata
         self.scale = scale
         self.zero_point = zero_point
         self.block_size = block_size
 
-    def __tensor_flatten__(self):
-        return self.tensor_data_attrs, [
-            getattr(self, attr) for attr in self.tensor_attributes
-        ]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        return cls(
-            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
-            *tensor_attributes,
-        )
-
-    def _apply_fn_to_data(self, fn):
-        return self.__class__(
-            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
-            *[getattr(self, attr) for attr in self.tensor_attributes],
-        )
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}(weight={self._data}, block_size={self.block_size}, "
-            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
-        )
-
     def _quantization_type(self):
         return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
 
-    def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        device = kwargs.pop("device")
-        return self.__class__(
-            self._data.to(device),
-            self.scale.to(device),
-            self.zero_point.to(device),
-            self.block_size,
-            self.shape,
-        )
-
     @classmethod
-    def from_float(
+    def from_hp(
         cls,
         w: torch.Tensor,
         block_size: List[int],
@@ -135,9 +94,8 @@ def from_float(
         scale = scale.to(w.dtype)
         zero_point = zero_point.to(w.dtype)
 
-        del w
         return Int4Tensor(
-            _data=wq,
+            qdata=wq,
             scale=scale,
             zero_point=zero_point,
             block_size=block_size,
@@ -155,14 +113,21 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
+    assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
+    assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
+    assert weight_tensor.zero_point.is_contiguous(), (
+        "Expected zero_point to be contiguous"
+    )
+
     orig_act_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]
+    input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1])
     res = torch.ops.fbgemm.bf16i4bf16_rowwise(
         input_tensor,
-        weight_tensor._data.contiguous(),
-        weight_tensor.scale.contiguous(),
-        weight_tensor.zero_point.contiguous(),
+        weight_tensor.qdata,
+        weight_tensor.scale,
+        weight_tensor.zero_point,
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)
     if bias is not None:
@@ -176,12 +141,17 @@ def _(func, types, args, kwargs):
         args[0],
         args[1],
     )
+    assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
+    assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
+    assert weight_tensor.zero_point.is_contiguous(), (
+        "Expected zero_point to be contiguous"
+    )
+
     orig_act_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]
-
     res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(
         input_tensor,
-        weight_tensor._data.contiguous(),
+        weight_tensor.qdata,
         weight_tensor.scale,
         weight_tensor.zero_point,
     )
@@ -189,66 +159,26 @@ def _(func, types, args, kwargs):
     return res
 
 
-@implements([aten.detach.default, aten.alias.default])
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-    )
-
-
-@implements(aten.clone.default)
-def _(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-    )
-
-
-def _same_metadata(self: "Int4Tensor", src: "Int4Tensor") -> bool:
-    return (
-        isinstance(self, Int4Tensor)
-        and isinstance(src, Int4Tensor)
-        and self.shape == src.shape
-        and self._data.shape == src._data.shape
-        and self.scale.shape == src.scale.shape
-        and self.zero_point.shape == src.zero_point.shape
-        and self.block_size == src.block_size
-    )
-
-
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-    )
-
-
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
-    """Only supports slicing for dim == 0 and dim == 1
-    _data has dimension: (N, K/2)
+    """Only supports slicing for dim == 0 and dim == 1
+    qdata has dimension: (N, K/2)
     scale and zero_point have dimension: (K/groups, N)
 
     dim, start, end, step are args that refer to the original tensor shape,
-    which is (N, K), and we need to map that to the transformed weight shape of _data,
+    which is (N, K), and we need to map that to the transformed weight shape of qdata,
     scale and zero_point
 
-    when dim == 0: we do a slice on _data dim 0, and on dim 1 of scale and zero_point,
+    when dim == 0: we do a slice on qdata dim 0, and on dim 1 of scale and zero_point,
     also adjust the start and end indexes based on the ratio between original shape and the shape
-    of _data and scale/zero_point
+    of qdata and scale/zero_point
 
-    when dim == 1: we do a slice on _data dim 1 and dim 0 of scale and zero_point and do the
+    when dim == 1: we do a slice on qdata dim 1 and dim 0 of scale and zero_point and do the
     same adjustment based on ratio
 
-    Note that we need to call slice on the _data, scale and zero_point directly because slice
-    is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_int4`
-    for
+    Note that we need to call slice on the qdata, scale and zero_point directly because slice
+    is an operation that needs to preserve aliasing, see `test_slice_preserves_aliasing` and
+    `test_slice_and_copy_similar_to_vllm` in `test_int4_tensor` for
     more details
     """
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
     assert step == 1
@@ -256,10 +186,10 @@ def _(func, types, args, kwargs):
     if end >= self.shape[dim]:
         end = self.shape[dim]
 
-    assert self._data.ndim == 2, (
-        f"Expected packed weight to have dim 2, got {self._data.dim}"
+    assert self.qdata.ndim == 2, (
+        f"Expected packed weight to have dim 2, got {self.qdata.ndim}"
     )
-    N, K_by_2 = self._data.shape
+    N, K_by_2 = self.qdata.shape
     sz_dim0, sz_dim1 = self.scale.shape
 
     data_len = self.shape[dim]
@@ -278,7 +208,7 @@ def _(func, types, args, kwargs):
             args,
             kwargs,
             self.__class__(
-                self._data,
+                self.qdata,
                 self.scale,
                 self.zero_point,
                 block_size=self.block_size,
@@ -294,13 +224,262 @@ def _(func, types, args, kwargs):
         start_sz = int(start / sz_ratio)
         end_sz = int(end / sz_ratio)
 
-    _data = aten.slice.Tensor(self._data, dim, start_pw, end_pw, step)
+    qdata = aten.slice.Tensor(self.qdata, dim, start_pw, end_pw, step)
     scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step)
     zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step)
-    packed_shape0, packed_shape1 = _data.shape
+    packed_shape0, packed_shape1 = qdata.shape
     new_shape = (packed_shape0, packed_shape1 * 2)
     new = self.__class__(
-        _data, scale, zero_point, block_size=self.block_size, shape=new_shape
+        qdata, scale, zero_point, block_size=self.block_size, shape=new_shape
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    """Concatenate multiple Int4 quantized tensors
+
+    For Int4Tensor, we need to concatenate qdata, scale, and zero_point tensors.
+    The concatenation behavior depends on the dimension and block_size configuration.
+
+    If the concatenation dimension is not the grouped (packed) dimension, we can just concatenate
+    qdata, scale and zero_point directly; note that scale and zero_point have reversed dimension order in 2D.
+    If the concatenation dimension is the grouped dimension (block_size > 1), we check that the scales and
+    zero_points from all tensors are equal and use the first tensor's scale and zero_point.
+    """
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    if not tensors:
+        raise ValueError("Cannot concatenate empty list of tensors")
+
+    tensor_0 = tensors[0]
+    dim = dim % tensor_0.ndim
+
+    # Validate that all tensors have compatible properties
+    for i in range(1, len(tensors)):
+        assert tensor_0.qdata.ndim == tensors[i].qdata.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.zero_point.ndim == tensors[i].zero_point.ndim
+        assert tensor_0.block_size == tensors[i].block_size
+
+    qdatas = [t.qdata for t in tensors]
+    scales = [t.scale for t in tensors]
+    zero_points = [t.zero_point for t in tensors]
+
+    # Concatenate the quantized data along the specified dimension
+    cat_qdata = aten.cat.default(qdatas, dim=dim)
+
+    # if concatenation happens in a non-grouped dimension, we need to concatenate
+    # the scales and zero_points as well
+    if tensor_0.block_size[dim] == 1:
+        # For scale and zero_point, the concatenation dimension depends on `dim`:
+        # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, K/group_size, N) for 3D
+        if cat_qdata.ndim == 2:  # 2D case
+            sz_dim = (
+                1 - dim
+            )  # If concatenating dim 0 (N), use dim 1 for scale; if dim 1 (K), use dim 0
+        else:  # 3D case
+            assert cat_qdata.ndim == 3
+            if dim in [1, 2]:
+                sz_dim = 3 - dim
+            else:
+                sz_dim = dim
+
+        cat_scale = aten.cat.default(scales, dim=sz_dim)
+        cat_zero_point = aten.cat.default(zero_points, dim=sz_dim)
+
+    else:
+        # if concatenation happens in the grouped (packed) dimension, we just need to verify
+        # that all scales and zero_points match
+        for i in range(1, len(tensors)):
+            assert torch.equal(tensor_0.scale, tensors[i].scale)
+            assert torch.equal(tensor_0.zero_point, tensors[i].zero_point)
+        cat_scale = scales[0]
+        cat_zero_point = zero_points[0]
+
+    # Calculate the new shape based on the concatenated qdata shape
+    # (the last qdata dim packs two int4 values per byte)
+    new_shape = list(cat_qdata.shape)
+    new_shape[-1] *= 2
+
+    new = Int4Tensor(
+        cat_qdata,
+        cat_scale,
+        cat_zero_point,
+        tensor_0.block_size,
+        new_shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+
+    # Transpose the quantized data
+    qdata = self.qdata.transpose(dim0, dim1).contiguous()
+    if self.scale.ndim == 3:
+        # since scale/zero_point dimension order is different
+        # (B, K/group_size, N), we'll need to remap the dim
+        remapped_dim0 = dim0
+        if dim0 in [1, 2]:
+            remapped_dim0 = 3 - dim0
+
+        remapped_dim1 = dim1
+        if dim1 in [1, 2]:
+            remapped_dim1 = 3 - dim1
+
+        scale = self.scale.transpose(remapped_dim0, remapped_dim1)
+        zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1)
+    else:
+        assert self.scale.ndim == 2, (
+            f"Only support ndim == 2 or 3, got: {self.scale.ndim}"
+        )
+        remapped_dim0 = 1 - dim0
+        remapped_dim1 = 1 - dim1
+        scale = self.scale.transpose(remapped_dim0, remapped_dim1)
+        zero_point = self.zero_point.transpose(remapped_dim0, remapped_dim1)
+
+    # Update block_size by swapping the dimensions
+    block_size = self.block_size.copy()
+    block_size[dim0], block_size[dim1] = block_size[dim1], block_size[dim0]
+
+    # Update shape by swapping the dimensions
+    new_shape = list(self.shape)
+    new_shape[dim0], new_shape[dim1] = new_shape[dim1], new_shape[dim0]
+
+    new = Int4Tensor(
+        qdata,
+        scale,
+        zero_point,
+        block_size,
+        new_shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    self, size = args
+    original_shape = self.shape
+    original_packing_dim = None
+    for i in range(len(original_shape)):
+        if original_shape[i] == (self.qdata.shape[i] * 2):
+            original_packing_dim = i
+    assert original_packing_dim is not None, "Didn't find a packing_dim"
+
+    if len(original_shape) == 3 and len(size) == 2:
+        # only support combining dim 0 and dim 1 together
+        assert original_shape[-1] == size[-1], (
+            f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
+        )
+        # the dim that int4 packing happens
+        if original_packing_dim in [0, 1]:
+            packing_dim = 0
+        else:
+            packing_dim = 1
+
+        block_size = self.block_size.copy()
+        block_size = [block_size[0] * block_size[1], block_size[2]]
+
+        qdata_shape = size.copy()
+        qdata_shape[packing_dim] //= 2
+        qdata = self.qdata.reshape(*qdata_shape)
+        sz_shape = []
+        for i in range(len(size)):
+            sz_shape.append(size[i] // block_size[i])
+        # scale and zero_point have reversed dimensions
+        sz_shape[0], sz_shape[1] = sz_shape[1], sz_shape[0]
+
+        scale = self.scale.reshape(*sz_shape)
+        zero_point = self.zero_point.reshape(*sz_shape)
+    elif len(original_shape) == 2 and len(size) == 3:
+        # only support extending dim 0 to 2, i.e. `t.unflatten(0, (num_experts, -1))`
+        assert original_shape[-1] == size[-1], (
+            f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
+        )
+        if original_packing_dim == 0:
+            packing_dim = 1
+        else:
+            # original_packing_dim is 1
+            packing_dim = 2
+
+        block_size = self.block_size.copy()
+        block_size = [1, block_size[0], block_size[1]]
+
+        qdata_shape = size.copy()
+        qdata_shape[packing_dim] //= 2
+        qdata = self.qdata.reshape(*qdata_shape)
+
+        sz_shape = []
+        for i in range(len(size)):
+            sz_shape.append(size[i] // block_size[i])
+
+        # scale and zero_point have reversed dimensions
+        sz_shape[1], sz_shape[2] = sz_shape[2], sz_shape[1]
+
+        scale = self.scale.reshape(*sz_shape)
+        zero_point = self.zero_point.reshape(*sz_shape)
+    elif len(original_shape) == len(size):
+        assert all(x == y or y == -1 for x, y in zip(original_shape, size)), (
+            f"Only support viewing with matching dimensions or -1, got: {original_shape}, {size}"
+        )
+        packing_dim = original_packing_dim
+        block_size = self.block_size
+        # no-op view: keep the existing component tensors
+        qdata = self.qdata
+        scale = self.scale
+        zero_point = self.zero_point
+    else:
+        raise AssertionError(
+            f"Only support reshaping from 2D to 3D, from 3D to 2D, or between the same ranks, requested: reshaping from {original_shape} to {size}"
+        )
+
+    shape = list(qdata.shape)
+    for i in range(len(shape)):
+        if i == packing_dim:
+            shape[i] *= 2
+
+    new = Int4Tensor(
+        qdata,
+        scale,
+        zero_point,
+        block_size,
+        shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.squeeze.dim)
+def _(func, types, args, kwargs):
+    self, dim = args
+
+    # Squeeze qdata
+    qdata = self.qdata.squeeze(dim=dim)
+
+    # For scale and zero_point, we need to squeeze based on the tensor layout
+    # Int4Tensor has scale and zero_point with shape (K/group_size, N) for 2D or (B, K/group_size, N) for 3D
+    if self.qdata.ndim == 2:  # 2D case
+        # qdata is (N, K/2), scale/zero_point is (K/group_size, N)
+        # When squeezing a qdata dim, we need to squeeze the scale/zero_point dim in reverse order
+        sz_dim = 1 - dim
+    else:  # 3D case
+        # qdata is (B, N, K/2), scale/zero_point is (B, K/group_size, N);
+        # the shared batch dim comes first, and squeezing is expected on dim 0
+        sz_dim = dim
+
+    scale = self.scale.squeeze(dim=sz_dim)
+    zero_point = self.zero_point.squeeze(dim=sz_dim)
+
+    # Update block_size by removing the squeezed dimension
+    new_block_size = list(self.block_size)
+    if len(qdata.shape) < len(new_block_size):
+        new_block_size.pop(dim)
+
+    # Update shape by removing the squeezed dimension
+    new_shape = list(self.shape)
+    if len(qdata.shape) < len(new_shape):
+        assert new_shape[dim] == 1
+        new_shape.pop(dim)
+
+    new = Int4Tensor(
+        qdata,
+        scale,
+        zero_point,
+        new_block_size,
+        new_shape,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
diff --git a/torchao/utils.py b/torchao/utils.py
index fb82b9f005..9aea6464cb 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -435,7 +435,20 @@ def _implements_common_tensor_ops(cls):
     aten = torch.ops.aten
 
     @implements(
-        [aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous]
+        [
+            torch.Tensor.contiguous,
+        ]
+    )
+    def _(func, types, args, kwargs):
+        return args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs))
+
+    @implements(
+        [
+            aten.detach.default,
+            aten.clone.default,
+            aten.alias.default,
+            aten.contiguous.default,
+        ]
+    )
     def _(func, types, args, kwargs):
         return return_and_correct_aliasing(
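
Note: the sketch below is not part of the diff; it is a minimal usage example of the renamed `Int4Tensor` API, mirroring `test_slice_preserves_aliasing` and `test_cat` above. It assumes a CUDA build with the fbgemm kernels that `Int4Tensor.from_hp` relies on; the shapes and the `block_size` convention follow `_int4_weight_only_quantize_tensor` and the class docstring in this diff.

```python
import torch

from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor

N, K, group_size = 256, 512, 128
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# block_size is [1, group_size] for a 2D weight, matching how
# _int4_weight_only_quantize_tensor builds it for VERSION == 2
qw = Int4Tensor.from_hp(w, [1, group_size])
assert qw.qdata.shape == (N, K // 2)  # last dim packed, two int4 values per byte
assert qw.scale.shape == (K // group_size, N)  # scale/zero_point dims are reversed

# slicing preserves aliasing: the sliced tensor views the same storage,
# which is what vLLM-style weight loading relies on
qw_half = qw.narrow(0, 0, N // 2)
assert qw_half.qdata.data_ptr() == qw.qdata.data_ptr()

# cat along dim 0 (block_size[0] == 1): qdata concatenates on dim 0,
# scale/zero_point on dim 1
qw_cat = torch.cat([qw, qw], dim=0)
assert qw_cat.shape == (2 * N, K)
assert qw_cat.scale.shape == (K // group_size, 2 * N)
```

The renames (`_data` to `qdata`, `from_float` to `from_hp`, `tensor_data_attrs` to `tensor_data_names`) let `TorchAOBaseTensor` supply `__tensor_flatten__`/`__tensor_unflatten__`, `detach`, `clone`, `copy_` and `contiguous` generically, which is why this diff can delete those per-class implementations from `int4_tensor.py`.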