
Commit b29792f

make generic

1 parent be89690

3 files changed: +26 -62 lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
Lines changed: 5 additions & 3 deletions

@@ -60,7 +60,7 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_zero_point",
             "weight_global_scale",
         )
-
+
     def compression_param_info(
         self,
         weight_shape: torch.Size,
@@ -75,11 +75,13 @@ def compression_param_info(
         :return: dictionary mapping compressed parameter names to shape and dtype
         """
         output = {
-            "weight_packed": (torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8),
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
         }
         return output
 
-
     def compress_weight(
         self,
         weight: Tensor,
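
For context, a minimal sketch (not part of this commit) of the shape math behind the "weight_packed" entry: NVFP4 packs two 4-bit values into one uint8 byte, so the last weight dimension is halved. The helper name below is hypothetical.

import torch

def packed_weight_info(weight_shape: torch.Size):
    # Hypothetical helper mirroring the "weight_packed" entry above:
    # two FP4 values per uint8 byte halves the last dimension.
    return torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8

shape, dtype = packed_weight_info(torch.Size((4096, 4096)))
assert shape == torch.Size((4096, 2048)) and dtype == torch.uint8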

src/compressed_tensors/quantization/lifecycle/forward.py
Lines changed: 11 additions & 37 deletions

@@ -257,11 +257,7 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
-        """
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
-        """
+
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
         columns = output.shape[-1]
@@ -294,25 +290,12 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        if len(x.shape) > 2:
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    x.shape[1],
-                    ceil(x.shape[-1] / group_size),
-                    group_size,
-                ),
-            )
-        else:
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    ceil(x.shape[-1] / group_size),
-                    group_size,
-                ),
-            )
+        # Maintain all dimensions apart from the last dim, which is divided by the group_size
+        reshaped_dims = tuple(x.shape[:-1]) + (
+            ceil(x.shape[-1] / group_size),
+            group_size,
+        )
+        x = torch.reshape(x, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
@@ -335,25 +318,16 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        if len(x.shape) > 3:
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
-            )
-        else:
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[-1] * output.shape[-2]),
-            )
+        original_shaped_dims = tuple(output.shape[:-2]) + (
+            output.shape[-1] * output.shape[-2],
+        )
+        output = torch.reshape(output, original_shaped_dims)
 
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        #if len(n_dims) > 2:
-        #    output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
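
The replacement logic is rank-agnostic: only the last dimension is split into (num_groups, group_size) on the way in, and the trailing two dimensions are collapsed on the way out, so 2-D weights and batched 3-D activations share one code path. A standalone sketch of that round trip, assuming the last dimension divides evenly by the group size (function names here are illustrative, not the library's):

import torch
from math import ceil

def to_groups(x: torch.Tensor, group_size: int) -> torch.Tensor:
    # Split only the last dim into (num_groups, group_size); leading dims untouched.
    dims = tuple(x.shape[:-1]) + (ceil(x.shape[-1] / group_size), group_size)
    return torch.reshape(x, dims)

def from_groups(x: torch.Tensor) -> torch.Tensor:
    # Collapse the trailing (num_groups, group_size) pair back into one dim.
    dims = tuple(x.shape[:-2]) + (x.shape[-1] * x.shape[-2],)
    return torch.reshape(x, dims)

w = torch.randn(8, 64)     # 2-D weight
a = torch.randn(2, 8, 64)  # 3-D batched activation
assert to_groups(w, 16).shape == (8, 4, 16)
assert to_groups(a, 16).shape == (2, 8, 4, 16)
assert torch.equal(from_groups(to_groups(a, 16)), a)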

src/compressed_tensors/quantization/utils/helpers.py
Lines changed: 10 additions & 22 deletions

@@ -175,35 +175,23 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
-        #if len(value.shape) > 2:
-        #    value = value.squeeze(0)
+
         if len(value.shape) > 2:
             dim = {0, 1, 2}
         else:
             dim = {0, 1}
 
-        reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)
+        reduce_dims = tuple(
+            idx for idx in range(len(value.shape) + 1) if idx not in dim
+        )
         keep_dims = False
 
-        if len(value.shape) > 2:
-            value = torch.reshape(
-                value,
-                (
-                    value.shape[0],
-                    value.shape[1],
-                    math.ceil(value.shape[-1] / args.group_size),
-                    args.group_size,
-                ),
-            )
-        else:
-            value = torch.reshape(
-                value,
-                (
-                    value.shape[0],
-                    math.ceil(value.shape[-1] / args.group_size),
-                    args.group_size,
-                ),
-            )
+        reshaped_dims = tuple(value.shape[:-1]) + (
+            math.ceil(value.shape[-1] / args.group_size),
+            args.group_size,
+        )
+        value = torch.reshape(value, reshaped_dims)
+
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
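
The same trick generalizes the dynamic-scale path: after grouping, the tensor gains one trailing dimension, and reduce_dims selects everything outside the kept leading dims so statistics are taken per group. A rough illustration of the shape behavior, using a plain abs-max in place of the library's actual observer logic (hypothetical function, assumes even divisibility):

import math
import torch

def dynamic_group_scales(value: torch.Tensor, group_size: int) -> torch.Tensor:
    # Illustrative only: one abs-max per group, for any input rank,
    # assuming group_size divides the last dimension evenly.
    kept = {0, 1, 2} if value.ndim > 2 else {0, 1}
    reduce_dims = tuple(i for i in range(value.ndim + 1) if i not in kept)
    dims = tuple(value.shape[:-1]) + (
        math.ceil(value.shape[-1] / group_size),
        group_size,
    )
    grouped = torch.reshape(value, dims)
    return grouped.abs().amax(dim=reduce_dims)

x2 = torch.randn(4, 64)
x3 = torch.randn(2, 4, 64)
assert dynamic_group_scales(x2, 16).shape == (4, 4)     # (rows, num_groups)
assert dynamic_group_scales(x3, 16).shape == (2, 4, 4)  # (batch, rows, num_groups)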
