|
15 | 15 |
|
16 | 16 | import pytest
|
17 | 17 | import torch
|
18 |
| -from compressed_tensors.utils import safe_permute |
19 |
| -from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES |
| 18 | +from compressed_tensors.utils.permute import safe_permute |
| 19 | +from tests.testing_utils import requires_gpu |
20 | 20 |
|
21 | 21 |
|
| 22 | +@requires_gpu |
22 | 23 | @pytest.mark.parametrize(
|
23 |
| - "dtype,device,exp_experimental", |
| 24 | + "dtype", |
24 | 25 | [
|
25 |
| - (torch.int8, torch.device("cpu"), False), |
26 |
| - (torch.int16, torch.device("cpu"), False), |
27 |
| - (torch.int32, torch.device("cpu"), False), |
28 |
| - (torch.int64, torch.device("cpu"), False), |
29 |
| - (torch.float16, torch.device("cpu"), False), |
30 |
| - (torch.float32, torch.device("cpu"), False), |
31 |
| - (torch.float64, torch.device("cpu"), False), |
32 |
| - (torch.float8_e4m3fn, torch.device("cpu"), True), |
| 26 | + torch.int8, |
| 27 | + torch.int16, |
| 28 | + torch.int32, |
| 29 | + torch.bfloat16, |
| 30 | + torch.float16, |
| 31 | + torch.float32, |
| 32 | + torch.float64, |
| 33 | + torch.float8_e4m3fn, |
33 | 34 | ],
|
34 | 35 | )
|
35 |
| -def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool): |
36 |
| - # some dtypes do not support arange initialization |
37 |
| - tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device) |
38 |
| - perm = torch.tensor([3, 1, 0, 2]) |
39 |
| - expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device) |
| 36 | +@pytest.mark.parametrize( |
| 37 | + "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")] |
| 38 | +) |
| 39 | +def test_safe_permute(dtype: torch.dtype, device: torch.device): |
| 40 | + value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device) |
| 41 | + perm = torch.tensor([3, 1, 0, 2], device=device) |
40 | 42 |
|
41 |
| - result = safe_permute(tensor, perm, dim=0) |
| 43 | + result = safe_permute(value, perm, dim=-1) |
42 | 44 |
|
43 |
| - if exp_experimental: |
44 |
| - assert (dtype, device) in _EXPERIMENTAL_DTYPES |
45 |
| - assert all(result == expected) |
| 45 | + if device.type != "meta": |
| 46 | + assert torch.equal(result.squeeze(0), perm.to(result.dtype)) |
0 commit comments