
Commit d06d093

basic support
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9f3f9ac commit d06d093

File tree

src/compressed_tensors/__init__.py
src/compressed_tensors/quantization/lifecycle/forward.py
src/compressed_tensors/utils/permute.py
tests/test_quantization/lifecycle/test_forward.py
tests/test_quantization/lifecycle/test_helpers.py

5 files changed: +64 -70 lines changed

src/compressed_tensors/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 
 from .compressors import *
 from .config import *
+from .logger import LoggerConfig, configure_logger, logger
 from .quantization import QuantizationConfig, QuantizationStatus
 from .utils import *
 from .version import *
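The newly exported names make the package logger configurable from user code. A minimal usage sketch, assuming `configure_logger` accepts a `LoggerConfig` instance and `logger` exposes standard logging methods; only the three exported names appear in this commit, their signatures do not:

import compressed_tensors

# Hypothetical usage; the LoggerConfig fields and the configure_logger
# signature are assumptions, not shown in this diff.
compressed_tensors.configure_logger(compressed_tensors.LoggerConfig())
compressed_tensors.logger.info("compressed-tensors logging configured")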

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 4 additions & 6 deletions
@@ -29,7 +29,6 @@
     calculate_range,
     compute_dynamic_scales_and_zp,
 )
-from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -265,8 +264,7 @@ def _process_quantization(
 ):
 
     output_dtype = dtype if dtype is not None else x.dtype
-    output = torch.zeros_like(x).to(output_dtype)
-    columns = output.shape[-1]
+    columns = x.size(-1)
 
     # TODO: make validation step for inputs
 
@@ -294,7 +292,7 @@ def _process_quantization(
         group_sizes = group_sizes[torch.argsort(group_indices)]
 
         perm = torch.argsort(g_idx)
-        x = safe_permute(x, perm, dim=1)
+        x = x.index_select(dim=-1, index=perm)
 
         # Maintain all dimensions except the last dim, which is divided by group_size
         reshaped_dims = (
@@ -324,11 +322,11 @@ def _process_quantization(
             global_scale=global_scale,
         )
 
-        output = output.flatten(start_dim=-2)
+        output = output.flatten(-2, -1)
         output = output.to(output_dtype)
 
         if not is_column_order:
-            output = safe_permute(output, torch.argsort(perm), dim=1)
+            output = output.index_select(dim=-1, index=torch.argsort(perm))
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
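The substantive change here is permuting along dim=-1 with Tensor.index_select instead of dim=1 with safe_permute, which is what lets _process_quantization accept batched (3-D) inputs. A small self-contained sketch of the equivalence and of the inverse-permutation step used above:

import torch

perm = torch.tensor([2, 0, 1])

# For 2-D inputs, dim=1 and dim=-1 are the same axis
x2d = torch.arange(6.0).reshape(2, 3)
assert torch.equal(x2d[:, perm], x2d.index_select(dim=-1, index=perm))

# ...but dim=-1 also generalizes to batched 3-D inputs
x3d = torch.arange(24.0).reshape(4, 2, 3)
y = x3d.index_select(dim=-1, index=perm)

# torch.argsort(perm) is the inverse permutation, as used to un-permute output
assert torch.equal(y.index_select(dim=-1, index=torch.argsort(perm)), x3d)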

src/compressed_tensors/utils/permute.py

Lines changed: 3 additions & 40 deletions
@@ -12,18 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Set, Tuple
-
 import torch
+from compressed_tensors.utils.helpers import deprecated
 
 
 __all__ = ["safe_permute"]
 
 
-# these datatypes are missing implementations required for standard permutation
-_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
-
-
+@deprecated("Tensor.index_select")
 def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
     """
     Perform out-of-place permutation without using torch.Tensor.index_put_,
@@ -34,37 +30,4 @@ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch
     :param dim: dimension along which to apply permutation
     :return: permuted value
     """
-    dtype_tuple = (value.dtype, value.device)
-
-    if dtype_tuple in _EXPERIMENTAL_DTYPES:
-        return _fallback_permute(value, perm, dim)
-
-    try:
-        return value[tuple([slice(None)] * dim + [perm])]
-    except RuntimeError:
-        # Mark dtype as experimental if advanced indexing fails
-        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
-        return _fallback_permute(value, perm, dim)
-
-
-def _fallback_permute(
-    value: torch.Tensor, perm: torch.Tensor, dim: int
-) -> torch.Tensor:
-    """
-    Fallback permutation method for experimental dtypes.
-
-    :param value: tensor to permute
-    :param perm: permutation map
-    :param dim: dimension along which to apply permutation
-    :return: permuted value
-    """
-    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
-    orig_slices = [slice(None)] * (dim + 1)
-    perm_slices = [slice(None)] * (dim + 1)
-
-    for index, perm_index in enumerate(perm):
-        orig_slices[dim] = index
-        perm_slices[dim] = perm_index
-        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
-
-    return value_ret
+    return value.index_select(dim, perm)
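safe_permute is now a thin, deprecated wrapper over Tensor.index_select. The deprecated helper imported from compressed_tensors.utils.helpers is not shown in this diff; below is a sketch of what such a decorator typically does, offered as an assumption about its behavior rather than the actual implementation:

import functools
import warnings

def deprecated(alternative: str):
    # Sketch only: the real helper lives in compressed_tensors.utils.helpers
    # and its signature and message format may differ.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated, use {alternative} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator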

tests/test_quantization/lifecycle/test_forward.py

Lines changed: 35 additions & 4 deletions
@@ -95,7 +95,7 @@ def test_forward_quantize(
 
 
 @pytest.mark.parametrize(
-    "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale",
+    "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale,batch_size",
     [
         (
             4,
@@ -106,6 +106,7 @@ def test_forward_quantize(
             torch.zeros((1,)),
             None,
             None,
+            None,
         ),
         (
             4,
@@ -116,6 +117,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             None,
             None,
+            None,
         ),
         (
             4,
@@ -126,6 +128,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
             None,
+            None,
         ),
         (
             8,
@@ -136,6 +139,7 @@ def test_forward_quantize(
             torch.zeros((1,)),
             None,
             None,
+            None,
         ),
         (
             8,
@@ -146,6 +150,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             None,
             None,
+            None,
         ),
         (
             8,
@@ -156,6 +161,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
             None,
+            None,
         ),
         (
             8,
@@ -166,6 +172,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             None,
             None,
+            None,
         ),
         (
             8,
@@ -176,17 +183,41 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
             None,
+            None,
+        ),
+        (
+            8,
+            "int",
+            QuantizationStrategy.GROUP,
+            128,
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
+            make_dummy_g_idx(1024, 128),
+            None,
+            5,
         ),
     ],
 )
-def test_fake_quantize_2d(
-    num_bits, type, strategy, group_size, scale, zero_point, g_idx, global_scale
+def test_fake_quantize(
+    num_bits,
+    type,
+    strategy,
+    group_size,
+    scale,
+    zero_point,
+    g_idx,
+    global_scale,
+    batch_size,
 ):
     args = QuantizationArgs(
         num_bits=num_bits, type=type, strategy=strategy, group_size=group_size
     )
 
-    x = torch.rand((512, 1024))
+    if batch_size is None:
+        x = torch.rand((512, 1024))
+    else:
+        x = torch.rand((batch_size, 512, 1024))
+
     fake_quantize(
         x=x,
         scale=scale,
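The new batch_size=5 case feeds a (5, 512, 1024) tensor through the same GROUP-strategy path as the 2-D cases. A shape-only sketch of why the grouped reshape and the flatten(-2, -1) in forward.py are batch-agnostic; unflatten stands in here for the reshape done via reshaped_dims:

import torch

group_size = 128
x = torch.rand((5, 512, 1024))  # matches the new batch_size=5 case

# split the last dim into (num_groups, group_size); leading dims pass through
grouped = x.unflatten(-1, (x.size(-1) // group_size, group_size))
assert grouped.shape == (5, 512, 8, 128)

# flatten(-2, -1) restores the original layout regardless of batch rank
assert torch.equal(grouped.flatten(-2, -1), x)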

tests/test_quantization/lifecycle/test_helpers.py

Lines changed: 21 additions & 20 deletions
@@ -15,31 +15,32 @@
 
 import pytest
 import torch
-from compressed_tensors.utils import safe_permute
-from compressed_tensors.utils.permute import _EXPERIMENTAL_DTYPES
+from compressed_tensors.utils.permute import safe_permute
+from tests.testing_utils import requires_gpu
 
 
+@requires_gpu
 @pytest.mark.parametrize(
-    "dtype,device,exp_experimental",
+    "dtype",
     [
-        (torch.int8, torch.device("cpu"), False),
-        (torch.int16, torch.device("cpu"), False),
-        (torch.int32, torch.device("cpu"), False),
-        (torch.int64, torch.device("cpu"), False),
-        (torch.float16, torch.device("cpu"), False),
-        (torch.float32, torch.device("cpu"), False),
-        (torch.float64, torch.device("cpu"), False),
-        (torch.float8_e4m3fn, torch.device("cpu"), True),
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.bfloat16,
+        torch.float16,
+        torch.float32,
+        torch.float64,
+        torch.float8_e4m3fn,
     ],
 )
-def test_safe_permute(dtype: torch.dtype, device: str, exp_experimental: bool):
-    # some dtypes do not support arange initialization
-    tensor = torch.tensor([0, 1, 2, 3], dtype=dtype, device=device)
-    perm = torch.tensor([3, 1, 0, 2])
-    expected = torch.tensor([3, 1, 0, 2], dtype=dtype, device=device)
+@pytest.mark.parametrize(
+    "device", [torch.device("cpu"), torch.device("cuda"), torch.device("meta")]
+)
+def test_safe_permute(dtype: torch.dtype, device: torch.device):
+    value = torch.tensor([[0, 1, 2, 3]], dtype=dtype, device=device)
+    perm = torch.tensor([3, 1, 0, 2], device=device)
 
-    result = safe_permute(tensor, perm, dim=0)
+    result = safe_permute(value, perm, dim=-1)
 
-    if exp_experimental:
-        assert (dtype, device) in _EXPERIMENTAL_DTYPES
-    assert all(result == expected)
+    if device.type != "meta":
+        assert torch.equal(result.squeeze(0), perm.to(result.dtype))
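Dropping the _EXPERIMENTAL_DTYPES bookkeeping relies on Tensor.index_select having kernels for dtypes such as float8, where advanced indexing historically raised RuntimeError, and on it working shape-only for meta tensors, which is why the rewritten test skips the value check on device "meta". A sketch, assuming a torch build with float8_e4m3fn support:

import torch

perm = torch.tensor([3, 1, 0, 2])

# index_select works for float8 where value[perm]-style indexing could fail
t = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float8_e4m3fn)
assert torch.equal(t.index_select(0, perm).to(torch.float32),
                   torch.tensor([3.0, 1.0, 0.0, 2.0]))

# meta tensors carry shapes but no data, so only the shape can be asserted
m = torch.empty(4, device="meta")
assert m.index_select(0, perm.to("meta")).shape == (4,)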
