Skip to content

Commit 4d89141

Browse files
kylesayrsdsikka
andauthored
Raise error if group_size is passed but wrong strategy (#149)
* add error * Update src/compressed_tensors/quantization/quant_args.py Co-authored-by: Dipika Sikka <[email protected]> * remove extra line * fix invalid configs it tests --------- Co-authored-by: Dipika Sikka <[email protected]>
1 parent b885229 commit 4d89141

File tree

4 files changed

+19
-3
lines changed

4 files changed

+19
-3
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,12 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
187187
f"strategy {strategy} requires group_size to be "
188188
"set to a positive value"
189189
)
190+
if (
191+
group_size is not None
192+
and group_size > 0
193+
and strategy != QuantizationStrategy.GROUP
194+
):
195+
raise ValueError("group_size requires strategy to be set to 'group'")
190196

191197
# validate activation ordering and strategy
192198
if actorder is not None and strategy != QuantizationStrategy.GROUP:

tests/test_compressors/test_fp8_quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
6767
],
6868
[
6969
QuantizationStrategy.CHANNEL,
70-
128,
70+
None,
7171
torch.rand((512, 1)) * 0.01,
7272
torch.zeros((512, 1), dtype=torch.int8),
7373
],

tests/test_compressors/test_int_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_dummy_quant_config(strategy, group_size=None):
5757
[
5858
QuantizationStrategy.CHANNEL,
5959
False,
60-
128,
60+
None,
6161
torch.rand((512, 1)) * 0.01,
6262
((torch.rand((512, 1)) - 0.5) * 127).to(torch.int8),
6363
],
@@ -102,7 +102,7 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
102102
],
103103
[
104104
QuantizationStrategy.CHANNEL,
105-
128,
105+
None,
106106
torch.rand((300, 1)) * 0.01,
107107
torch.zeros((300, 1), dtype=torch.int8),
108108
],

tests/test_quantization/test_quant_args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ def test_group():
4343
with pytest.raises(ValueError):
4444
QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=-1)
4545

46+
args = QuantizationArgs(group_size=128, strategy="group")
47+
assert args.group_size == 128
48+
assert args.strategy == "group"
49+
50+
with pytest.raises(ValueError):
51+
QuantizationArgs(strategy=QuantizationStrategy.GROUP)
52+
53+
with pytest.raises(ValueError):
54+
QuantizationArgs(strategy="tensor", group_size=128)
55+
4656

4757
def test_block():
4858
kwargs = {"strategy": "block", "block_structure": "2x4"}

0 commit comments

Comments
 (0)