
Commit 82ee671

basic support
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8d35794 commit 82ee671

File tree

src/compressed_tensors/__init__.py
src/compressed_tensors/quantization/lifecycle/forward.py
tests/test_quantization/lifecycle/test_forward.py

3 files changed: +38 −7 lines


src/compressed_tensors/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -19,6 +19,7 @@

 from .compressors import *
 from .config import *
+from .logger import LoggerConfig, configure_logger, logger
 from .quantization import QuantizationConfig, QuantizationStatus
 from .utils import *
 from .version import *
```
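The package root now re-exports the logging utilities alongside the existing config and quantization symbols. A minimal sketch of the new import surface; the no-argument `LoggerConfig()` construction, the `config=` keyword, and the logger's loguru-style interface are assumptions, not confirmed by this diff:

```python
from compressed_tensors import LoggerConfig, configure_logger, logger

# Assumed usage: apply a default configuration, then log through the shared
# logger; LoggerConfig's fields are not shown in this commit
configure_logger(config=LoggerConfig())
logger.info("logging utilities are importable from the package root")
```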

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -264,8 +264,7 @@ def _process_quantization(
 ):

     output_dtype = dtype if dtype is not None else x.dtype
-    output = torch.zeros_like(x).to(output_dtype)
-    columns = output.shape[-1]
+    columns = x.size(-1)

     # TODO: make validation step for inputs
@@ -323,7 +322,7 @@ def _process_quantization(
                 global_scale=global_scale,
             )

-        output = output.flatten(start_dim=-2)
+        output = output.flatten(-2, -1)
         output = output.to(output_dtype)

     if not is_column_order:
```
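This forward-pass change is what enables batched inputs: `columns` is now read from `x.size(-1)`, which ignores any leading batch dimensions, and the up-front `torch.zeros_like(x)` allocation is dropped, presumably because `output` is assigned in the quantize/dequantize branches before use. A small plain-PyTorch sketch (shapes hypothetical, mirroring the new test case) of why `flatten(-2, -1)` restores a grouped tensor for any number of batch dimensions:

```python
import torch

# Hypothetical shapes: batch of 5, 512 rows, 1024 columns in groups of 128
x = torch.rand((5, 512, 1024))
group_size = 128
columns = x.size(-1)  # 1024, unaffected by leading batch dims

# Group-wise view: split the last dim into (num_groups, group_size)
grouped = x.unflatten(-1, (columns // group_size, group_size))  # (5, 512, 8, 128)

# flatten(-2, -1) merges exactly the two grouping dims, so the original
# shape is recovered regardless of how many batch dims precede them
restored = grouped.flatten(-2, -1)
assert restored.shape == x.shape
```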

tests/test_quantization/lifecycle/test_forward.py

Lines changed: 35 additions & 4 deletions

```diff
@@ -95,7 +95,7 @@ def test_forward_quantize(


 @pytest.mark.parametrize(
-    "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale",
+    "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale,batch_size",
     [
         (
             4,
@@ -106,6 +106,7 @@ def test_forward_quantize(
             torch.zeros((1,)),
             None,
             None,
+            None,
         ),
         (
             4,
@@ -116,6 +117,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             None,
             None,
+            None,
         ),
         (
             4,
@@ -126,6 +128,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
             None,
+            None,
         ),
         (
             8,
@@ -136,6 +139,7 @@ def test_forward_quantize(
             torch.zeros((1,)),
             None,
             None,
+            None,
         ),
         (
             8,
@@ -146,6 +150,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             None,
             None,
+            None,
         ),
         (
             8,
@@ -156,6 +161,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
             None,
+            None,
         ),
         (
             8,
@@ -166,6 +172,7 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             None,
             None,
+            None,
         ),
         (
             8,
@@ -176,17 +183,41 @@ def test_forward_quantize(
             torch.zeros((512, 8)),
             make_dummy_g_idx(1024, 128),
             None,
+            None,
+        ),
+        (
+            8,
+            "int",
+            QuantizationStrategy.GROUP,
+            128,
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
+            make_dummy_g_idx(1024, 128),
+            None,
+            5,
         ),
     ],
 )
-def test_fake_quantize_2d(
-    num_bits, type, strategy, group_size, scale, zero_point, g_idx, global_scale
+def test_fake_quantize(
+    num_bits,
+    type,
+    strategy,
+    group_size,
+    scale,
+    zero_point,
+    g_idx,
+    global_scale,
+    batch_size,
 ):
     args = QuantizationArgs(
         num_bits=num_bits, type=type, strategy=strategy, group_size=group_size
     )

-    x = torch.rand((512, 1024))
+    if batch_size is None:
+        x = torch.rand((512, 1024))
+    else:
+        x = torch.rand((batch_size, 512, 1024))
+
     fake_quantize(
         x=x,
         scale=scale,
```
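The renamed test (`test_fake_quantize_2d` to `test_fake_quantize`) now exercises the batched path end to end via the new `batch_size` parameter. A rough standalone version of the new `batch_size=5` case, assuming `fake_quantize` is importable as in the test module (`make_dummy_g_idx` is a test-local helper, so `g_idx` is left at its default here):

```python
import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import fake_quantize

# Shapes taken from the new parametrized case: 1024 columns in groups of 128
# give 8 groups per row, hence (512, 8) scales and zero points
args = QuantizationArgs(
    num_bits=8, type="int", strategy=QuantizationStrategy.GROUP, group_size=128
)
scale = torch.rand((512, 8)) * 0.01
zero_point = torch.zeros((512, 8))

# Batched input (batch_size=5); previously only 2D (512, 1024) was tested
x = torch.rand((5, 512, 1024))

out = fake_quantize(x=x, scale=scale, zero_point=zero_point, args=args)
assert out.shape == x.shape  # fake quantization preserves input shape
```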

0 commit comments
