File tree: 1 file changed (+8 -2 lines changed)
src/compressed_tensors/quantization/lifecycle: 1 file changed (+8 -2 lines changed)
Original file line number | Diff line number | Diff line change
@@ -202,8 +202,8 @@ def initialize_qparams(
202
202
expected_shape = (1 , 1 )
203
203
204
204
elif strategy == QuantizationStrategy .CHANNEL :
205
- if len (observed_shape ) < 1 :
206
- raise ValueError ("Channel quant requires at least 1 observed dimension " )
205
+ if len (observed_shape ) < 2 :
206
+ raise ValueError ("Channel quant requires at least 2 observed dimensions " )
207
207
208
208
expected_shape = (observed_shape [- 2 ], 1 )
209
209
@@ -234,6 +234,12 @@ def initialize_qparams(
234
234
num_cols = strategy_cdiv (observed_shape [- 1 ], block_structure [- 1 ], strategy )
235
235
expected_shape = (num_rows , num_cols )
236
236
237
+ elif strategy == QuantizationStrategy .ATTN_HEAD :
238
+ if len (observed_shape ) < 2 :
239
+ raise ValueError ("Attention quant requires at least 2 observed dimensions" )
240
+
241
+ expected_shape = (observed_shape [- 2 ], 1 )
242
+
237
243
else :
238
244
assert False , f"Unknown strategy { strategy } "
239
245
You can’t perform that action at this time.
0 commit comments