diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 58a16dfba..2e539b070 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -29,7 +29,6 @@
     calculate_range,
     compute_dynamic_scales_and_zp,
 )
-from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -294,7 +293,7 @@ def _process_quantization(
             group_sizes = group_sizes[torch.argsort(group_indices)]
 
             perm = torch.argsort(g_idx)
-            x = safe_permute(x, perm, dim=1)
+            x = x.index_select(-1, perm)
 
         # Maintain all dimensions except the last dim, which is divided by group_size
         reshaped_dims = (
@@ -328,7 +327,8 @@ def _process_quantization(
             output = output.to(output_dtype)
 
         if not is_column_order:
-            output = safe_permute(output, torch.argsort(perm), dim=1)
+            inv_perm = torch.argsort(perm)
+            output = output.index_select(-1, inv_perm)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
diff --git a/src/compressed_tensors/utils/permute.py b/src/compressed_tensors/utils/permute.py
index e31d4862b..86a0ee805 100644
--- a/src/compressed_tensors/utils/permute.py
+++ b/src/compressed_tensors/utils/permute.py
@@ -12,18 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Set, Tuple
-
 import torch
+from compressed_tensors.utils.helpers import deprecated
 
 
 __all__ = ["safe_permute"]
 
 
-# these datatypes are missing implementations required for standard permutation
-_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
-
-
+@deprecated("Tensor.index_select")
 def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
     """
     Perform out-of-place permutation without using torch.Tensor.index_put_,
@@ -34,37 +30,4 @@ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch
     :param dim: dimension along which to apply permutation
     :return: permuted value
     """
-    dtype_tuple = (value.dtype, value.device)
-
-    if dtype_tuple in _EXPERIMENTAL_DTYPES:
-        return _fallback_permute(value, perm, dim)
-
-    try:
-        return value[tuple([slice(None)] * dim + [perm])]
-    except RuntimeError:
-        # Mark dtype as experimental if advanced indexing fails
-        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
-        return _fallback_permute(value, perm, dim)
-
-
-def _fallback_permute(
-    value: torch.Tensor, perm: torch.Tensor, dim: int
-) -> torch.Tensor:
-    """
-    Fallback permutation method for experimental dtypes.
-
-    :param value: tensor to permute
-    :param perm: permutation map
-    :param dim: dimension along which to apply permutation
-    :return: permuted value
-    """
-    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
-    orig_slices = [slice(None)] * (dim + 1)
-    perm_slices = [slice(None)] * (dim + 1)
-
-    for index, perm_index in enumerate(perm):
-        orig_slices[dim] = index
-        perm_slices[dim] = perm_index
-        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
-
-    return value_ret
+    return value.index_select(dim, perm)
diff --git a/tests/test_quantization/lifecycle/test_helpers.py b/tests/test_quantization/lifecycle/test_helpers.py
index 08d916544..20fd39da4 100644
--- a/tests/test_quantization/lifecycle/test_helpers.py
+++ b/tests/test_quantization/lifecycle/test_helpers.py
@@ -15,31 +15,35 @@
 
 import pytest
 import torch
-from compressed_tensors.utils import safe_permute
-from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES
+from compressed_tensors.utils.permute import safe_permute
+from tests.testing_utils import requires_gpu
 
 
+@requires_gpu
+@pytest.mark.unit
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 @pytest.mark.parametrize(
-    "dtype,device,exp_experimental",
+    "dtype",
     [
-        (torch.int8, torch.device("cpu"), False),
-        (torch.int16, torch.device("cpu"), False),
-        (torch.int32, torch.device("cpu"), False),
-        (torch.int64, torch.device("cpu"), False),
-        (torch.float16, torch.device("cpu"), False),
-        (torch.float32, torch.device("cpu"), False),
-        (torch.float64, torch.device("cpu"), False),
-        (torch.float8_e4m3fn, torch.device("cpu"), True),
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+        torch.bfloat16,
+        torch.float16,
+        torch.float32,
+        torch.float64,
+        torch.float8_e4m3fn,
     ],
 )
-def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool):
-    # some dtypes do not support arange initialization
-    tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device)
-    perm = torch.tensor([3, 1, 0, 2])
-    expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device)
+@pytest.mark.parametrize(
+    "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")]
+)
+def test_safe_permute(dtype: torch.dtype, device: torch.device):
+    value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device)
+    perm = torch.tensor([3, 1, 0, 2], device=device)
 
-    result = safe_permute(tensor, perm, dim=0)
+    result = safe_permute(value, perm, dim=-1)
 
-    if exp_experimental:
-        assert (dtype, device) in _EXPERIMENTAL_DTYPES
-        assert all(result == expected)
+    if device.type != "meta":
+        assert torch.equal(result.squeeze(0), perm.to(result.dtype))
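
For reference, a minimal sketch (not part of the diff; tensor shapes and values are illustrative only) of the equivalence this change relies on: permuting along a dimension with advanced indexing, as the removed safe_permute path did, matches Tensor.index_select on the same dimension, and torch.argsort(perm) gives the inverse permutation used as inv_perm in _process_quantization.

import torch

# Illustrative sketch only: compare the old advanced-indexing permutation
# with the index_select call that replaces it.
value = torch.arange(12, dtype=torch.float32).reshape(3, 4)
perm = torch.tensor([3, 1, 0, 2])

old_style = value[:, perm]                # advanced indexing along dim=1
new_style = value.index_select(-1, perm)  # replacement used in forward.py
assert torch.equal(old_style, new_style)

# argsort(perm) is the inverse permutation, so applying it restores the
# original column order, mirroring the inv_perm handling above.
inv_perm = torch.argsort(perm)
assert torch.equal(new_style.index_select(-1, inv_perm), value)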