6 changes: 3 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -29,7 +29,6 @@
calculate_range,
compute_dynamic_scales_and_zp,
)
from compressed_tensors.utils import safe_permute
from torch.nn import Module


@@ -294,7 +293,7 @@ def _process_quantization(
group_sizes = group_sizes[torch.argsort(group_indices)]

perm = torch.argsort(g_idx)
x = safe_permute(x, perm, dim=1)
x = x.index_select(-1, perm)

# Maintain all dimensions except the last dim, which is divided by group_size
reshaped_dims = (
@@ -328,7 +327,8 @@ def _process_quantization(
output = output.to(output_dtype)

if not is_column_order:
output = safe_permute(output, torch.argsort(perm), dim=1)
inv_perm = torch.argsort(perm)
output = output.index_select(-1, inv_perm)

else: # covers channel, token and tensor strategies
if do_quantize:
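
The two hunks above swap safe_permute(x, perm, dim=1) for x.index_select(-1, perm) and then undo the reordering with the inverse permutation. A minimal, self-contained sketch of that round trip (the g_idx values and tensor shapes here are made up for illustration; the real logic lives in _process_quantization):

import torch

# Hypothetical act-order layout: g_idx maps each column to its quantization group.
g_idx = torch.tensor([1, 0, 1, 0])
x = torch.arange(8, dtype=torch.float32).reshape(2, 4)

perm = torch.argsort(g_idx)              # column order sorted by group
x_sorted = x.index_select(-1, perm)      # same effect as the old safe_permute(x, perm, dim=1)

# ... grouped quantize / dequantize would operate on x_sorted here ...

inv_perm = torch.argsort(perm)           # inverse permutation
restored = x_sorted.index_select(-1, inv_perm)
assert torch.equal(restored, x)          # original column order is recovered
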
43 changes: 3 additions & 40 deletions src/compressed_tensors/utils/permute.py
@@ -12,18 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Set, Tuple

import torch
from compressed_tensors.utils.helpers import deprecated


__all__ = ["safe_permute"]


# these datatypes are missing implementations required for standard permutation
_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()


@deprecated("Tensor.index_select")
def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""
Perform out-of-place permutation without using torch.Tensor.index_put_,
@@ -34,37 +30,4 @@ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch
:param dim: dimension along which to apply permutation
:return: permuted value
"""
dtype_tuple = (value.dtype, value.device)

if dtype_tuple in _EXPERIMENTAL_DTYPES:
return _fallback_permute(value, perm, dim)

try:
return value[tuple([slice(None)] * dim + [perm])]
except RuntimeError:
# Mark dtype as experimental if advanced indexing fails
_EXPERIMENTAL_DTYPES.add(dtype_tuple)
return _fallback_permute(value, perm, dim)


def _fallback_permute(
value: torch.Tensor, perm: torch.Tensor, dim: int
) -> torch.Tensor:
"""
Fallback permutation method for experimental dtypes.
:param value: tensor to permute
:param perm: permutation map
:param dim: dimension along which to apply permutation
:return: permuted value
"""
value_ret = value.clone() # cannot use zeros_like b/c of missing impl.
orig_slices = [slice(None)] * (dim + 1)
perm_slices = [slice(None)] * (dim + 1)

for index, perm_index in enumerate(perm):
orig_slices[dim] = index
perm_slices[dim] = perm_index
value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]

return value_ret
return value.index_select(dim, perm)
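
With the experimental-dtype bookkeeping and the element-by-element fallback removed, safe_permute is now a thin wrapper that forwards to Tensor.index_select and is marked deprecated (the test file below silences the resulting DeprecationWarning). A quick equivalence check, not part of the PR, showing that index_select matches the advanced-indexing path the old implementation tried first and is likewise out-of-place:

import torch

value = torch.arange(12, dtype=torch.float16).reshape(3, 4)
perm = torch.tensor([2, 0, 1])

a = value[perm]                  # advanced indexing along dim 0 (the old fast path)
b = value.index_select(0, perm)  # what the deprecated safe_permute(value, perm, dim=0) now returns
assert torch.equal(a, b)         # same result; value itself is left untouched
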
44 changes: 24 additions & 20 deletions tests/test_quantization/lifecycle/test_helpers.py
@@ -15,31 +15,35 @@

import pytest
import torch
from compressed_tensors.utils import safe_permute
from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES
from compressed_tensors.utils.permute import safe_permute
from tests.testing_utils import requires_gpu


@requires_gpu
@pytest.mark.unit
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.parametrize(
"dtype,device,exp_experimental",
"dtype",
[
(torch.int8, torch.device("cpu"), False),
(torch.int16, torch.device("cpu"), False),
(torch.int32, torch.device("cpu"), False),
(torch.int64, torch.device("cpu"), False),
(torch.float16, torch.device("cpu"), False),
(torch.float32, torch.device("cpu"), False),
(torch.float64, torch.device("cpu"), False),
(torch.float8_e4m3fn, torch.device("cpu"), True),
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bfloat16,
torch.float16,
torch.float32,
torch.float64,
torch.float8_e4m3fn,
],
)
def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool):
# some dtypes do not support arange initialization
tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device)
perm = torch.tensor([3, 1, 0, 2])
expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device)
@pytest.mark.parametrize(
"device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")]
)
def test_safe_permute(dtype: torch.dtype, device: torch.device):
value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device)
perm = torch.tensor([3, 1, 0, 2], device=device)

result = safe_permute(tensor, perm, dim=0)
result = safe_permute(value, perm, dim=-1)

if exp_experimental:
assert (dtype, device) in _EXPERIMENTAL_DTYPES
assert all(result == expected)
if device.type != "meta":
assert torch.equal(result.squeeze(0), perm.to(result.dtype))
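
The rewritten test permutes a single identity row, so the expected output is the permutation itself: indexing [[0, 1, 2, 3]] with perm [3, 1, 0, 2] along the last dim yields [[3, 1, 0, 2]], hence the comparison against perm.to(result.dtype). On the meta device tensors carry shapes but no storage, so only shape and dtype propagation is exercised there and the value check is skipped. A CPU-only sketch of the same assertion (mirrors the test values, assuming a recent PyTorch):

import torch

value = torch.tensor([[0, 1, 2, 3]], dtype=torch.float32)
perm = torch.tensor([3, 1, 0, 2])

result = value.index_select(-1, perm)    # equivalent to safe_permute(value, perm, dim=-1)
print(result)                            # tensor([[3., 1., 0., 2.]])
assert torch.equal(result.squeeze(0), perm.to(result.dtype))
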