Skip to content

Commit b593884

Browse files
committed
add compression param; update qdq for batch greater than 1
1 parent 5478b43 commit b593884

File tree

3 files changed

+80
-28
lines changed

3 files changed

+80
-28
lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ def compression_param_names(self) -> Tuple[str]:
6060
"weight_zero_point",
6161
"weight_global_scale",
6262
)
63+
64+
def compression_param_info(
    self,
    weight_shape: torch.Size,
    quantization_args: Optional["QuantizationArgs"] = None,
) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
    """
    Creates a dictionary of expected shapes and dtypes for each compression
    parameter used by the compressor.

    :param weight_shape: uncompressed weight shape
    :param quantization_args: quantization parameters for the weight
        (currently unused here — kept for interface parity with other
        compressors; TODO confirm whether packed shape should depend on it)
    :return: dictionary mapping compressed parameter names to shape and dtype
    """
    rows, cols = weight_shape[0], weight_shape[1]
    # NVFP4 packs two FP4 values per uint8; use ceil division so an odd
    # column count still reserves a byte for the final lone nibble
    # (identical to the previous `cols // 2` whenever cols is even).
    packed_cols = (cols + 1) // 2
    return {
        "weight_packed": (torch.Size((rows, packed_cols)), torch.uint8),
    }
81+
6382

6483
def compress_weight(
6584
self,

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,14 @@ def _process_quantization(
253253
QuantizationStrategy.GROUP,
254254
QuantizationStrategy.TENSOR_GROUP,
255255
):
256+
"""
256257
n_dims = x.shape
257258
if len(n_dims) > 2:
258259
x = x.squeeze(0)
259-
260+
"""
260261
output_dtype = dtype if dtype is not None else x.dtype
261262
output = torch.zeros_like(x).to(output_dtype)
262-
columns = output.shape[1]
263+
columns = output.shape[-1]
263264

264265
# TODO: make validation step for inputs
265266

@@ -289,14 +290,25 @@ def _process_quantization(
289290
perm = torch.argsort(g_idx)
290291
x = safe_permute(x, perm, dim=1)
291292

292-
x = torch.reshape(
293-
x,
294-
(
295-
x.shape[0],
296-
ceil(x.shape[1] / group_size),
297-
group_size,
298-
),
299-
)
293+
if len(x.shape) > 2:
294+
x = torch.reshape(
295+
x,
296+
(
297+
x.shape[0],
298+
x.shape[1],
299+
ceil(x.shape[-1] / group_size),
300+
group_size,
301+
),
302+
)
303+
else:
304+
x = torch.reshape(
305+
x,
306+
(
307+
x.shape[0],
308+
ceil(x.shape[-1] / group_size),
309+
group_size,
310+
),
311+
)
300312

301313
if do_quantize:
302314
output = _quantize(
@@ -319,18 +331,24 @@ def _process_quantization(
319331
global_scale=global_scale,
320332
)
321333

322-
output = torch.reshape(
323-
output,
324-
(output.shape[0], output.shape[1] * output.shape[2]),
325-
)
334+
if len(x.shape) > 3:
335+
output = torch.reshape(
336+
output,
337+
(output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
338+
)
339+
else:
340+
output = torch.reshape(
341+
output,
342+
(output.shape[0], output.shape[-1] * output.shape[-2]),
343+
)
326344

327345
output = output.to(output_dtype)
328346

329347
if not is_column_order:
330348
output = safe_permute(output, torch.argsort(perm), dim=1)
331349

332-
if len(n_dims) > 2:
333-
output = output.unsqueeze(0)
350+
#if len(n_dims) > 2:
351+
# output = output.unsqueeze(0)
334352

335353
else: # covers channel, token and tensor strategies
336354
if do_quantize:

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -167,28 +167,43 @@ def compute_dynamic_scales_and_zp(
167167

168168
keep_dims = True
169169
if args.strategy == QuantizationStrategy.TOKEN:
170-
dim = {1, 2}
170+
dim = {0, 1, 2}
171171
reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
172172
elif args.strategy == QuantizationStrategy.TENSOR:
173173
reduce_dims = None
174174
elif args.strategy in (
175175
QuantizationStrategy.TENSOR_GROUP,
176176
QuantizationStrategy.GROUP,
177177
):
178+
#if len(value.shape) > 2:
179+
# value = value.squeeze(0)
178180
if len(value.shape) > 2:
179-
value = value.squeeze(0)
181+
dim = {0, 1, 2}
182+
else:
183+
dim = {0, 1}
180184

181-
dim = {0, 1}
182-
reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
185+
reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)
183186
keep_dims = False
184-
value = torch.reshape(
185-
value,
186-
(
187-
value.shape[0],
188-
math.ceil(value.shape[1] / args.group_size),
189-
args.group_size,
190-
),
191-
)
187+
188+
if len(value.shape) > 2:
189+
value = torch.reshape(
190+
value,
191+
(
192+
value.shape[0],
193+
value.shape[1],
194+
math.ceil(value.shape[-1] / args.group_size),
195+
args.group_size,
196+
),
197+
)
198+
else:
199+
value = torch.reshape(
200+
value,
201+
(
202+
value.shape[0],
203+
math.ceil(value.shape[-1] / args.group_size),
204+
args.group_size,
205+
),
206+
)
192207
else:
193208
supported_strategies = (
194209
QuantizationStrategy.TOKEN,

0 commit comments

Comments
 (0)