
Commit 42363c3

[Utils] Deprecate safe_permute (#464)

Authored by Kyle Sayers
Signed-off-by: Kyle Sayers <[email protected]>

1 parent 891da51 · commit 42363c3

3 files changed: +30 −63 lines changed
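This commit replaces the `safe_permute` helper with the built-in `torch.Tensor.index_select` at its call sites and marks the helper itself as deprecated. As a rough sketch (not code from the diff), the two forms are interchangeable for permuting along a dimension, and the permutation is undone by selecting with `torch.argsort(perm)`:

```python
import torch

x = torch.randn(2, 8)
perm = torch.randperm(8)

# Deprecated helper:   safe_permute(x, perm, dim=1)
# Replacement method:  x.index_select(-1, perm)
x_permuted = x.index_select(-1, perm)

# argsort of a permutation yields its inverse, so selecting with it
# restores the original column order.
inv_perm = torch.argsort(perm)
assert torch.equal(x_permuted.index_select(-1, inv_perm), x)
```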


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,6 @@
     calculate_range,
     compute_dynamic_scales_and_zp,
 )
-from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -294,7 +293,7 @@ def _process_quantization(
             group_sizes = group_sizes[torch.argsort(group_indices)]
 
             perm = torch.argsort(g_idx)
-            x = safe_permute(x, perm, dim=1)
+            x = x.index_select(-1, perm)
 
         # Maintain all dimensions except the last dim, which is divided by group_size
         reshaped_dims = (
@@ -328,7 +327,8 @@ def _process_quantization(
         output = output.to(output_dtype)
 
         if not is_column_order:
-            output = safe_permute(output, torch.argsort(perm), dim=1)
+            inv_perm = torch.argsort(perm)
+            output = output.index_select(-1, inv_perm)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
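In `_process_quantization`, columns are first sorted so that members of the same quantization group are adjacent, and the original column order is restored after (de)quantization. A simplified sketch of that round trip, assuming a hypothetical group assignment `g_idx` (the names mirror the diff, but this is illustration only):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(2, 6)
g_idx = torch.tensor([2, 0, 1, 0, 2, 1])  # hypothetical group index per column

# Sort columns so members of the same group sit next to each other.
perm = torch.argsort(g_idx)
x_grouped = x.index_select(-1, perm)

# ... per-group quantize/dequantize would operate on x_grouped here ...

# Select with the inverse permutation to restore the original order.
inv_perm = torch.argsort(perm)
assert torch.equal(x_grouped.index_select(-1, inv_perm), x)
```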

src/compressed_tensors/utils/permute.py

Lines changed: 3 additions & 40 deletions
@@ -12,18 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Set, Tuple
-
 import torch
+from compressed_tensors.utils.helpers import deprecated
 
 
 __all__ = ["safe_permute"]
 
 
-# these datatypes are missing implementations required for standard permutation
-_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
-
-
+@deprecated("Tensor.index_select")
 def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
     """
     Perform out-of-place permutation without using torch.Tensor.index_put_,
@@ -34,37 +30,4 @@ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
     :param dim: dimension along which to apply permutation
     :return: permuted value
     """
-    dtype_tuple = (value.dtype, value.device)
-
-    if dtype_tuple in _EXPERIMENTAL_DTYPES:
-        return _fallback_permute(value, perm, dim)
-
-    try:
-        return value[tuple([slice(None)] * dim + [perm])]
-    except RuntimeError:
-        # Mark dtype as experimental if advanced indexing fails
-        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
-        return _fallback_permute(value, perm, dim)
-
-
-def _fallback_permute(
-    value: torch.Tensor, perm: torch.Tensor, dim: int
-) -> torch.Tensor:
-    """
-    Fallback permutation method for experimental dtypes.
-
-    :param value: tensor to permute
-    :param perm: permutation map
-    :param dim: dimension along which to apply permutation
-    :return: permuted value
-    """
-    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
-    orig_slices = [slice(None)] * (dim + 1)
-    perm_slices = [slice(None)] * (dim + 1)
-
-    for index, perm_index in enumerate(perm):
-        orig_slices[dim] = index
-        perm_slices[dim] = perm_index
-        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
-
-    return value_ret
+    return value.index_select(dim, perm)
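The `deprecated` decorator imported from `compressed_tensors.utils.helpers` is not shown in this diff; judging by the call site, it takes the name of the suggested replacement. A minimal sketch of such a decorator (an assumption, not the library's actual implementation):

```python
import functools
import warnings


def deprecated(replacement: str):
    # Sketch only: warn callers and point them at the replacement API.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated, use {replacement} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```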

tests/test_quantization/lifecycle/test_helpers.py

Lines changed: 24 additions & 20 deletions
@@ -15,31 +15,35 @@
 
 import pytest
 import torch
-from compressed_tensors.utils import safe_permute
-from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES
+from compressed_tensors.utils.permute import safe_permute
+from tests.testing_utils import requires_gpu
 
 
+@requires_gpu
+@pytest.mark.unit
+@pytest.mark.filterwarnings("ignore::DeprecationWarning")
 @pytest.mark.parametrize(
-    "dtype,device,exp_experimental",
+    "dtype",
     [
-        (torch.int8, torch.device("cpu"), False),
-        (torch.int16, torch.device("cpu"), False),
-        (torch.int32, torch.device("cpu"), False),
-        (torch.int64, torch.device("cpu"), False),
-        (torch.float16, torch.device("cpu"), False),
-        (torch.float32, torch.device("cpu"), False),
-        (torch.float64, torch.device("cpu"), False),
-        (torch.float8_e4m3fn, torch.device("cpu"), True),
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+        torch.bfloat16,
+        torch.float16,
+        torch.float32,
+        torch.float64,
+        torch.float8_e4m3fn,
     ],
 )
-def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool):
-    # some dtypes do not support arange initialization
-    tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device)
-    perm = torch.tensor([3, 1, 0, 2])
-    expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device)
+@pytest.mark.parametrize(
+    "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")]
+)
+def test_safe_permute(dtype: torch.dtype, device: torch.device):
+    value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device)
+    perm = torch.tensor([3, 1, 0, 2], device=device)
 
-    result = safe_permute(tensor, perm, dim=0)
+    result = safe_permute(value, perm, dim=-1)
 
-    if exp_experimental:
-        assert (dtype, device) in _EXPERIMENTAL_DTYPES
-    assert all(result == expected)
+    if device.type != "meta":
+        assert torch.equal(result.squeeze(0), perm.to(result.dtype))
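The updated test parametrizes over CPU, CUDA, and the meta device; meta tensors carry only shape and dtype, so the value comparison is skipped there. On real devices the expected output is simply `perm` cast to the input dtype, since selecting positions `[3, 1, 0, 2]` from the row `[0, 1, 2, 3]` reproduces the permutation itself. A CPU-only check of that expectation (a standalone sketch, not part of the test file):

```python
import torch

value = torch.tensor([[0, 1, 2, 3]], dtype=torch.float16)
perm = torch.tensor([3, 1, 0, 2])

# index_select along the last dim gathers columns 3, 1, 0, 2 in that order.
result = value.index_select(-1, perm)
assert torch.equal(result.squeeze(0), perm.to(result.dtype))
```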
