Commit be89690

add compression param; update qdq for batch greater than 1
1 parent: 3d49764

File tree

3 files changed: +80 −28 lines


src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 19 additions & 0 deletions
Lines changed: 19 additions & 0 deletions

@@ -60,6 +60,25 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_zero_point",
             "weight_global_scale",
         )
+
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        output = {
+            "weight_packed": (torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8),
+        }
+        return output
 
     def compress_weight(
         self,
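The new compression_param_info entry declares that NVFP4 packs two 4-bit values into each uint8 byte, so the packed weight has half the columns of the uncompressed weight. A minimal sketch of how a caller might consume the returned mapping to pre-allocate compressed parameters; the torch.empty allocation loop and the 4096x4096 shape are illustrative assumptions, not the library's actual loading path:

import torch

# Mapping as returned by compression_param_info for a hypothetical
# 4096 x 4096 weight: two FP4 values packed per uint8 byte.
weight_shape = torch.Size((4096, 4096))
param_info = {
    "weight_packed": (
        torch.Size((weight_shape[0], weight_shape[1] // 2)),
        torch.uint8,
    ),
}

# Pre-allocate each compressed parameter from its declared shape/dtype.
params = {
    name: torch.empty(shape, dtype=dtype)
    for name, (shape, dtype) in param_info.items()
}
assert params["weight_packed"].shape == (4096, 2048)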

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 34 additions & 16 deletions
Lines changed: 34 additions & 16 deletions

@@ -257,13 +257,14 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
+        """
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
-
+        """
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-        columns = output.shape[1]
+        columns = output.shape[-1]
 
         # TODO: make validation step for inputs
 
@@ -293,14 +294,25 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        x = torch.reshape(
-            x,
-            (
-                x.shape[0],
-                ceil(x.shape[1] / group_size),
-                group_size,
-            ),
-        )
+        if len(x.shape) > 2:
+            x = torch.reshape(
+                x,
+                (
+                    x.shape[0],
+                    x.shape[1],
+                    ceil(x.shape[-1] / group_size),
+                    group_size,
+                ),
+            )
+        else:
+            x = torch.reshape(
+                x,
+                (
+                    x.shape[0],
+                    ceil(x.shape[-1] / group_size),
+                    group_size,
+                ),
+            )
 
         if do_quantize:
             output = _quantize(
@@ -323,18 +335,24 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        output = torch.reshape(
-            output,
-            (output.shape[0], output.shape[1] * output.shape[2]),
-        )
+        if len(x.shape) > 3:
+            output = torch.reshape(
+                output,
+                (output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
+            )
+        else:
+            output = torch.reshape(
+                output,
+                (output.shape[0], output.shape[-1] * output.shape[-2]),
+            )
 
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        if len(n_dims) > 2:
-            output = output.unsqueeze(0)
+        #if len(n_dims) > 2:
+        #    output = output.unsqueeze(0)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
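Taken together, the three hunks let _process_quantization handle batched (3D) inputs directly: instead of squeezing away a leading batch dimension, a [batch, tokens, hidden] tensor is reshaped to [batch, tokens, hidden / group_size, group_size], quantized per group, and merged back. A standalone sketch of that reshape round-trip under the same branch conditions; the actual quantize/dequantize step is elided, and group_size = 4 plus the random input are illustrative:

import torch
from math import ceil

group_size = 4
x = torch.randn(2, 3, 8)  # [batch, tokens, hidden] with batch > 1

# Split the last dim into groups, mirroring the new 3D branch.
if len(x.shape) > 2:
    grouped = torch.reshape(
        x, (x.shape[0], x.shape[1], ceil(x.shape[-1] / group_size), group_size)
    )
else:
    grouped = torch.reshape(
        x, (x.shape[0], ceil(x.shape[-1] / group_size), group_size)
    )

# ... _quantize/_dequantize would operate on `grouped` here ...
output = grouped  # stand-in for the (de)quantized result

# Merge the group dims back, mirroring the new 4D branch.
if len(output.shape) > 3:
    output = torch.reshape(
        output,
        (output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
    )
else:
    output = torch.reshape(
        output, (output.shape[0], output.shape[-1] * output.shape[-2])
    )

assert output.shape == x.shape  # no squeeze/unsqueeze round-trip needed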

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 27 additions & 12 deletions
Lines changed: 27 additions & 12 deletions

@@ -167,28 +167,43 @@ def compute_dynamic_scales_and_zp(
 
     keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
-        dim = {1, 2}
+        dim = {0, 1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
    elif args.strategy in (
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
+        #if len(value.shape) > 2:
+        #    value = value.squeeze(0)
         if len(value.shape) > 2:
-            value = value.squeeze(0)
+            dim = {0, 1, 2}
+        else:
+            dim = {0, 1}
 
-        dim = {0, 1}
-        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)
         keep_dims = False
-        value = torch.reshape(
-            value,
-            (
-                value.shape[0],
-                math.ceil(value.shape[1] / args.group_size),
-                args.group_size,
-            ),
-        )
+
+        if len(value.shape) > 2:
+            value = torch.reshape(
+                value,
+                (
+                    value.shape[0],
+                    value.shape[1],
+                    math.ceil(value.shape[-1] / args.group_size),
+                    args.group_size,
+                ),
+            )
+        else:
+            value = torch.reshape(
+                value,
+                (
+                    value.shape[0],
+                    math.ceil(value.shape[-1] / args.group_size),
+                    args.group_size,
+                ),
+            )
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
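With this change, compute_dynamic_scales_and_zp derives reduce_dims from the post-reshape rank (len(value.shape) + 1, since grouping adds one trailing axis), so only the group axis is reduced when computing per-group dynamic scales, for both 2D and batched 3D inputs. A small sketch of the dimension arithmetic; it uses amax as a stand-in for the module's observer logic, collapses the diff's two reshape branches into one value.shape[:-1] + (...) expression, and the shapes and group_size are illustrative:

import math
import torch

group_size = 4  # illustrative

for value in (torch.randn(6, 8), torch.randn(2, 3, 8)):  # 2D, and 3D with batch
    # Keep every axis except the trailing group axis added by the reshape.
    dim = {0, 1, 2} if len(value.shape) > 2 else {0, 1}
    reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)

    # Group the last dim (compact form of the diff's two reshape branches).
    grouped = torch.reshape(
        value,
        value.shape[:-1] + (math.ceil(value.shape[-1] / group_size), group_size),
    )

    # One dynamic scale per group: reduce over the group axis only.
    scale = grouped.abs().amax(dim=reduce_dims, keepdim=False)
    print(tuple(value.shape), "->", tuple(grouped.shape), "scale:", tuple(scale.shape))
    # (6, 8)    -> (6, 2, 4)    scale: (6, 2)
    # (2, 3, 8) -> (2, 3, 2, 4) scale: (2, 3, 2)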
