File tree: 2 files changed, +6 −12 lines changed
src/compressed_tensors/quantization: 2 files changed, +6 −12 lines changed
Original file line number | Diff line number | Diff line change
@@ -291,11 +291,11 @@ def _process_quantization(
291
291
x = safe_permute (x , perm , dim = 1 )
292
292
293
293
# Keep all leading dimensions unchanged; split only the last dim into (num_groups, group_size)
294
- reshaped_dims = tuple ( x . shape [: - 1 ]) + (
294
+ reshaped_dims = (
295
295
ceil (x .shape [- 1 ] / group_size ),
296
296
group_size ,
297
297
)
298
- x = torch . reshape ( x , reshaped_dims )
298
+ x = x . unflatten ( - 1 , reshaped_dims )
299
299
300
300
if do_quantize :
301
301
output = _quantize (
@@ -318,11 +318,7 @@ def _process_quantization(
318
318
global_scale = global_scale ,
319
319
)
320
320
321
- original_shaped_dims = tuple (output .shape [:- 2 ]) + (
322
- output .shape [- 1 ] * output .shape [- 2 ],
323
- )
324
- output = torch .reshape (output , original_shaped_dims )
325
-
321
+ output = output .flatten (start_dim = - 2 )
326
322
output = output .to (output_dtype )
327
323
328
324
if not is_column_order :
Original file line number | Diff line number | Diff line change
@@ -176,16 +176,14 @@ def compute_dynamic_scales_and_zp(
176
176
QuantizationStrategy .GROUP ,
177
177
):
178
178
179
- reduce_dims = tuple (
180
- idx for idx in range (len (value .shape ) + 1 ) if idx not in range (value .dim ())
181
- )
179
+ reduce_dims = - 1
182
180
keep_dims = False
183
181
184
- reshaped_dims = tuple ( value . shape [: - 1 ]) + (
182
+ reshaped_dims = (
185
183
math .ceil (value .shape [- 1 ] / args .group_size ),
186
184
args .group_size ,
187
185
)
188
- value = torch . reshape ( value , reshaped_dims )
186
+ value = value . unflatten ( - 1 , reshaped_dims )
189
187
190
188
else :
191
189
supported_strategies = (
You can’t perform that action at this time.
0 commit comments