@@ -258,13 +258,15 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
 
 @requires_accelerate()
 @pytest.mark.parametrize(
-    "ignore",
+    "target,should_raise_warning",
     [
-        ("lm_head", "re:.*gate"),
-        ("lm_head", "re:.*foobarbaz"),
+        [("Linear",), False],
+        [("Linear", "re:.*foobarbaz"), True],
     ],
 )
-def test_apply_quantization_status(ignore):
+def test_apply_quantization_status(caplog, target, should_raise_warning):
+    import logging
+
     # load a dense, unquantized tiny llama model
     model = get_tinyllama_model()
     quantization_config_dict = {
@@ -279,13 +281,19 @@ def test_apply_quantization_status(ignore):
                     "symmetric": False,
                     "strategy": "tensor",
                 },
-                "targets": ["Linear"],
+                "targets": target,
             }
         },
+        "ignore": ["lm_head", "re:.*gate"],
     }
-    quantization_config_dict["ignore"] = ignore
 
     config = QuantizationConfig(**quantization_config_dict)
     config.quantization_status = QuantizationStatus.CALIBRATION
 
-    apply_quantization_config(model, config)
+    # mismatch in the ignore key of quantization_config_dict
+    with caplog.at_level(logging.WARNING):
+        apply_quantization_config(model, config)
+    if should_raise_warning:
+        assert len(caplog.text) > 0
+    else:
+        assert len(caplog.text) == 0
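
For reference, the assertion added in this diff relies on pytest's built-in `caplog` fixture, which records log output emitted at or above a given level. Below is a minimal, self-contained sketch of the same pattern; `maybe_warn` is a hypothetical stand-in for `apply_quantization_config`, used only to illustrate how the warning assertion works.

```python
import logging

import pytest

_LOGGER = logging.getLogger(__name__)


def maybe_warn(should_warn: bool):
    # hypothetical stand-in for apply_quantization_config: emits a warning
    # only when asked to, mimicking an unmatched "ignore" entry
    if should_warn:
        _LOGGER.warning("module listed in ignore was not found in the model")


@pytest.mark.parametrize("should_warn", [False, True])
def test_maybe_warn(caplog, should_warn):
    # caplog.at_level ensures WARNING records are captured within this block
    with caplog.at_level(logging.WARNING):
        maybe_warn(should_warn)
    if should_warn:
        assert len(caplog.text) > 0
    else:
        assert len(caplog.text) == 0
```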