Skip to content

Commit 1cfd8bb

Browse files
committed
update
1 parent 3548dc5 commit 1cfd8bb

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,11 @@ def _process_quantization(
291291
x = safe_permute(x, perm, dim=1)
292292

293293
# Maintain all dimensions apart from the last dim, which is divided by the group_size
294-
reshaped_dims = tuple(x.shape[:-1]) + (
294+
reshaped_dims = (
295295
ceil(x.shape[-1] / group_size),
296296
group_size,
297297
)
298-
x = torch.reshape(x, reshaped_dims)
298+
x = x.unflatten(-1, reshaped_dims)
299299

300300
if do_quantize:
301301
output = _quantize(
@@ -318,11 +318,7 @@ def _process_quantization(
318318
global_scale=global_scale,
319319
)
320320

321-
original_shaped_dims = tuple(output.shape[:-2]) + (
322-
output.shape[-1] * output.shape[-2],
323-
)
324-
output = torch.reshape(output, original_shaped_dims)
325-
321+
output = output.flatten(start_dim=-2)
326322
output = output.to(output_dtype)
327323

328324
if not is_column_order:

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,14 @@ def compute_dynamic_scales_and_zp(
176176
QuantizationStrategy.GROUP,
177177
):
178178

179-
reduce_dims = tuple(
180-
idx for idx in range(len(value.shape) + 1) if idx not in range(value.dim())
181-
)
179+
reduce_dims = -1
182180
keep_dims = False
183181

184-
reshaped_dims = tuple(value.shape[:-1]) + (
182+
reshaped_dims = (
185183
math.ceil(value.shape[-1] / args.group_size),
186184
args.group_size,
187185
)
188-
value = torch.reshape(value, reshaped_dims)
186+
value = value.unflatten(-1, reshaped_dims)
189187

190188
else:
191189
supported_strategies = (

0 commit comments

Comments (0)