Skip to content

Commit 8053b51

Browse files
committed
increase num of required observed dims
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 72560d4 commit 8053b51

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -202,8 +202,8 @@ def initialize_qparams(
202202
expected_shape = (1, 1)
203203

204204
elif strategy == QuantizationStrategy.CHANNEL:
205-
if len(observed_shape) < 1:
206-
raise ValueError("Channel quant requires at least 1 observed dimension")
205+
if len(observed_shape) < 2:
206+
raise ValueError("Channel quant requires at least 2 observed dimensions")
207207

208208
expected_shape = (observed_shape[-2], 1)
209209

@@ -234,6 +234,12 @@ def initialize_qparams(
234234
num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
235235
expected_shape = (num_rows, num_cols)
236236

237+
elif strategy == QuantizationStrategy.ATTN_HEAD:
238+
if len(observed_shape) < 2:
239+
raise ValueError("Attention quant requires at least 2 observed dimensions")
240+
241+
expected_shape = (observed_shape[-2], 1)
242+
237243
else:
238244
assert False, f"Unknown strategy {strategy}"
239245

0 commit comments

Comments
 (0)