Skip to content

Commit 5121d47

Browse files
authored
Add block strategy and structure validation (#483)
* add validation

Signed-off-by: Kyle Sayers <[email protected]>

* fix test

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent b5dc1e9 commit 5121d47

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -259,6 +259,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
259259
# extract user-passed values from dictionary
260260
strategy = model.strategy
261261
group_size = model.group_size
262+
block_structure = model.block_structure
262263
actorder = model.actorder
263264
dynamic = model.dynamic
264265
observer = model.observer
@@ -277,7 +278,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
277278
"strategy='group' and group_size = -1 for 'channel'"
278279
)
279280

280-
# validate strategy and group
281+
# validate group strategy
281282
if strategy == QuantizationStrategy.GROUP:
282283
if group_size is None or group_size <= 0:
283284
raise ValueError(
@@ -292,6 +293,14 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
292293
):
293294
raise ValueError("group_size requires strategy to be set to 'group'")
294295

296+
# validate block strategy
297+
has_block_strategy = strategy == QuantizationStrategy.BLOCK
298+
has_block_structure = block_structure is not None
299+
if has_block_strategy and not has_block_structure:
300+
raise ValueError(f"Block strategy requires block structure\n{model}")
301+
if has_block_structure and not has_block_strategy:
302+
raise ValueError(f"Block structure requires block strategy\n{model}")
303+
295304
# validate activation ordering and strategy
296305
if actorder is not None and strategy != QuantizationStrategy.GROUP:
297306
raise ValueError(

tests/test_quantization/test_utils/test_helpers.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,36 +32,42 @@
3232
[
3333
(
3434
False,
35-
QuantizationStrategy.TENSOR,
35+
"tensor",
3636
torch.Size(
3737
[
3838
1,
3939
]
4040
),
4141
),
42-
(True, QuantizationStrategy.CHANNEL, torch.Size([1, 1])),
43-
(True, QuantizationStrategy.GROUP, torch.Size([1, 1])),
42+
(True, "channel", torch.Size([1, 1])),
43+
(True, "group", torch.Size([1, 1])),
4444
(
4545
False,
46-
QuantizationStrategy.BLOCK,
46+
"block",
4747
torch.Size(
4848
[
4949
1,
5050
]
5151
),
5252
),
53-
(True, QuantizationStrategy.TOKEN, torch.Size([1, 1])),
53+
(True, "token", torch.Size([1, 1])),
5454
],
5555
)
5656
def test_calculate_qparams(keepdims, strategy, exp_shape):
57-
value = torch.randn(14, 5)
57+
value = torch.empty(5, 6)
5858
min_val = torch.amin(value, dim=tuple(), keepdims=keepdims)
5959
max_val = torch.amax(value, dim=tuple(), keepdims=keepdims)
6060

6161
if strategy == QuantizationStrategy.GROUP:
6262
args = QuantizationArgs(strategy=strategy, group_size=2)
63+
elif strategy == QuantizationStrategy.BLOCK:
64+
args = QuantizationArgs(strategy=strategy, block_structure=[1, 3])
6365
else:
64-
args = QuantizationArgs(strategy=strategy)
66+
args = QuantizationArgs(
67+
strategy=strategy,
68+
group_size=(2 if strategy == "group" else None),
69+
block_structure=([1, 3] if strategy == "block" else None),
70+
)
6571
scale, zp = calculate_qparams(min_val, max_val, args)
6672
assert scale.shape == exp_shape
6773
assert zp.shape == exp_shape

0 commit comments

Comments
 (0)