
Commit dc36cfa

make generic
Parent: b593884

3 files changed (+34, -66 lines)


src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -60,7 +60,7 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_zero_point",
             "weight_global_scale",
         )
-
+
     def compression_param_info(
         self,
         weight_shape: torch.Size,
@@ -75,11 +75,13 @@ def compression_param_info(
         :return: dictionary mapping compressed parameter names to shape and dtype
         """
         output = {
-            "weight_packed": (torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8),
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
         }
         return output
 
-
     def compress_weight(
         self,
         weight: Tensor,
```
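The reformatting above is cosmetic, but the shape arithmetic it wraps is the packing contract for NVFP4: two 4-bit values share one uint8 byte, so the packed weight has half as many columns as the original. A minimal sketch with a hypothetical weight shape (not the library's API):

```python
import torch

# Hypothetical (out_features, in_features) weight shape; any even column
# count works, since two FP4 values pack into each uint8 byte.
weight_shape = torch.Size((128, 256))

# Mirrors the "weight_packed" entry above: same rows, half the columns, uint8.
packed_shape = torch.Size((weight_shape[0], weight_shape[1] // 2))
print(packed_shape)  # torch.Size([128, 128])
```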

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 19 additions & 41 deletions
```diff
@@ -112,17 +112,21 @@ def dequantize(
             if scale.shape[1] == 1:
                 args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
             # Scale height matches input or is 1 -> group quantization across columns
-            #
+            #
             # Example 1: scale.shape[0] == 1
             #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
             #
-            # Example 2: scale.shape[0] == x_q.shape[0]
+            # Example 2: scale.shape[0] == x_q.shape[0]
             #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
             elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                 group_size = int(x_q.shape[1] / scale.shape[1])
-                args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
             else:
-                args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                )
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
```
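The strategy inference in this hunk can be read straight off the shapes. A self-contained sketch of the rule (standalone code, not the library's; strategies returned as plain strings for illustration):

```python
import torch

def infer_strategy(x_q: torch.Tensor, scale: torch.Tensor) -> str:
    # Mirrors the 2-D branch above: one scale per row -> channel;
    # scale height of 1 or matching x_q -> group; anything else -> block.
    if scale.shape[1] == 1:
        return "channel"
    if scale.shape[0] in (1, x_q.shape[0]):
        group_size = x_q.shape[1] // scale.shape[1]
        return f"group(group_size={group_size})"
    return f"block(block_structure={tuple(scale.shape)})"

x_q = torch.zeros(4, 8, dtype=torch.int8)
print(infer_strategy(x_q, torch.ones(4, 1)))  # channel
print(infer_strategy(x_q, torch.ones(1, 4)))  # group(group_size=2), Example 1
print(infer_strategy(x_q, torch.ones(4, 4)))  # group(group_size=2), Example 2
print(infer_strategy(x_q, torch.ones(2, 4)))  # block(block_structure=(2, 4))
```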
```diff
@@ -253,11 +257,7 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
-        """
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
-        """
+
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
         columns = output.shape[-1]
```
```diff
@@ -290,25 +290,12 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        if len(x.shape) > 2:
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    x.shape[1],
-                    ceil(x.shape[-1] / group_size),
-                    group_size,
-                ),
-            )
-        else:
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    ceil(x.shape[-1] / group_size),
-                    group_size,
-                ),
-            )
+        # Maintain all dimensions apart from the last dim, which is divided by the group_size
+        reshaped_dims = tuple(x.shape[:-1]) + (
+            ceil(x.shape[-1] / group_size),
+            group_size,
+        )
+        x = torch.reshape(x, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
```
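This hunk is the heart of the "make generic" change: rather than branching on tensor rank, keep all leading dimensions and split only the last axis into (n_groups, group_size). A minimal sketch showing a 2-D and a 3-D input taking the same path (hypothetical shapes):

```python
from math import ceil
import torch

group_size = 4
for shape in [(8, 16), (2, 8, 16)]:  # hypothetical 2-D and 3-D inputs
    x = torch.randn(shape)
    # Same expression as the diff: leading dims untouched, last dim split.
    reshaped_dims = tuple(x.shape[:-1]) + (ceil(x.shape[-1] / group_size), group_size)
    grouped = torch.reshape(x, reshaped_dims)
    print(shape, "->", tuple(grouped.shape))
    # (8, 16) -> (8, 4, 4) and (2, 8, 16) -> (2, 8, 4, 4)
```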
```diff
@@ -331,25 +318,16 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        if len(x.shape) > 3:
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
-            )
-        else:
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[-1] * output.shape[-2]),
-            )
+        original_shaped_dims = tuple(output.shape[:-2]) + (
+            output.shape[-1] * output.shape[-2],
+        )
+        output = torch.reshape(output, original_shaped_dims)
 
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        #if len(n_dims) > 2:
-        #    output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
```
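The restore step is generic in the same way: merge the trailing (n_groups, group_size) pair back into a single axis, leaving every leading dimension alone, which undoes the grouping reshape exactly. A short sketch (hypothetical shapes):

```python
import torch

output = torch.randn(2, 8, 4, 4)  # hypothetical grouped tensor from the step above

# Same expression as the diff: collapse the last two axes into one.
original_shaped_dims = tuple(output.shape[:-2]) + (output.shape[-1] * output.shape[-2],)
restored = torch.reshape(output, original_shaped_dims)
print(tuple(restored.shape))  # (2, 8, 16), the pre-grouping shape
```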

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 10 additions & 22 deletions
```diff
@@ -175,35 +175,23 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
-        #if len(value.shape) > 2:
-        #    value = value.squeeze(0)
+
         if len(value.shape) > 2:
             dim = {0, 1, 2}
         else:
             dim = {0, 1}
 
-        reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)
+        reduce_dims = tuple(
+            idx for idx in range(len(value.shape) + 1) if idx not in dim
+        )
         keep_dims = False
 
-        if len(value.shape) > 2:
-            value = torch.reshape(
-                value,
-                (
-                    value.shape[0],
-                    value.shape[1],
-                    math.ceil(value.shape[-1] / args.group_size),
-                    args.group_size,
-                ),
-            )
-        else:
-            value = torch.reshape(
-                value,
-                (
-                    value.shape[0],
-                    math.ceil(value.shape[-1] / args.group_size),
-                    args.group_size,
-                ),
-            )
+        reshaped_dims = tuple(value.shape[:-1]) + (
+            math.ceil(value.shape[-1] / args.group_size),
+            args.group_size,
+        )
+        value = torch.reshape(value, reshaped_dims)
+
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
```
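Reading aid for the reduce_dims logic: value gains one axis after the grouping reshape, so the generator ranges over len(value.shape) + 1 axes and keeps those in dim, leaving only the trailing group_size axis to be reduced, which yields one statistic per group. A sketch under assumed shapes; the amax call stands in for whichever reduction the helper actually applies:

```python
import math
import torch

value = torch.randn(8, 16)  # hypothetical 2-D input
group_size = 4
dim = {0, 1}  # a 3-D input would use {0, 1, 2}

# One extra axis will exist after reshaping, hence len(value.shape) + 1.
reduce_dims = tuple(i for i in range(len(value.shape) + 1) if i not in dim)
print(reduce_dims)  # (2,)

grouped = value.reshape(8, math.ceil(16 / group_size), group_size)
per_group_max = grouped.amax(dim=reduce_dims)  # one statistic per group
print(per_group_max.shape)  # torch.Size([8, 4])
```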
