Skip to content

Commit 4b1d777

Browse files
committed
Test warning raised on unmatched target
1 parent c9b8ba2 commit 4b1d777

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -258,13 +258,15 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
258258

259259
@requires_accelerate()
260260
@pytest.mark.parametrize(
261-
"ignore",
261+
"target,should_raise_warning",
262262
[
263-
("lm_head", "re:.*gate"),
264-
("lm_head", "re:.*foobarbaz"),
263+
[("Linear",), False],
264+
[("Linear", "re:.*foobarbaz"), True],
265265
],
266266
)
267-
def test_apply_quantization_status(ignore):
267+
def test_apply_quantization_status(caplog, target, should_raise_warning):
268+
import logging
269+
268270
# load a dense, unquantized tiny llama model
269271
model = get_tinyllama_model()
270272
quantization_config_dict = {
@@ -279,13 +281,19 @@ def test_apply_quantization_status(ignore):
279281
"symmetric": False,
280282
"strategy": "tensor",
281283
},
282-
"targets": ["Linear"],
284+
"targets": target,
283285
}
284286
},
287+
"ignore": ["lm_head", "re:.*gate"],
285288
}
286-
quantization_config_dict["ignore"] = ignore
287289

288290
config = QuantizationConfig(**quantization_config_dict)
289291
config.quantization_status = QuantizationStatus.CALIBRATION
290292

291-
apply_quantization_config(model, config)
293+
# mismatch in the targets key of quantization_config_dict
294+
with caplog.at_level(logging.WARNING):
295+
apply_quantization_config(model, config)
296+
if should_raise_warning:
297+
assert len(caplog.text) > 0
298+
else:
299+
assert len(caplog.text) == 0

0 commit comments

Comments
 (0)