@@ -1306,6 +1306,38 @@ def call(self, inputs):
1306
1306
}
1307
1307
1308
1308
1309
+ def _get_simple_model ():
1310
+ """Builds a simple sequential model for testing."""
1311
+ return models .Sequential ([layers .Dense (10 , input_shape = (5 ,))])
1312
+
1313
+
1314
+ quantize_test_cases = [
1315
+ # --- Error Scenarios ---
1316
+ (
1317
+ "gptq" ,
1318
+ {"wbits" : 4 }, # Invalid config (dict, not GPTQConfig)
1319
+ TypeError ,
1320
+ "must pass a `config` argument of type" ,
1321
+ "gptq_with_invalid_config" ,
1322
+ ),
1323
+ (
1324
+ "int8" ,
1325
+ GPTQConfig (dataset = ["test" ], tokenizer = lambda x : x ),
1326
+ ValueError ,
1327
+ "is only supported for 'gptq' mode" ,
1328
+ "non_gptq_with_unsupported_config" ,
1329
+ ),
1330
+ # --- Valid Scenario ---
1331
+ (
1332
+ "int8" ,
1333
+ None , # No config, which is correct
1334
+ None , # No exception expected
1335
+ None ,
1336
+ "non_gptq_runs_without_error" ,
1337
+ ),
1338
+ ]
1339
+
1340
+
1309
1341
@pytest .mark .requires_trainable_backend
1310
1342
class TestModelQuantization :
1311
1343
def _run_gptq_test_on_dataset (self , dataset , ** config_kwargs ):
@@ -1359,3 +1391,28 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
1359
1391
def test_quantize_gptq_combinations (self , dataset , config ):
1360
1392
"""Runs GPTQ tests across different datasets and config variations."""
1361
1393
self ._run_gptq_test_on_dataset (dataset , ** config )
1394
+
1395
+ @pytest .mark .parametrize (
1396
+ "mode, config, expected_exception, match_message, test_id" ,
1397
+ quantize_test_cases ,
1398
+ ids = [case [- 1 ] for case in quantize_test_cases ],
1399
+ )
1400
+ def test_quantize_scenarios (
1401
+ self , mode , config , expected_exception , match_message , test_id
1402
+ ):
1403
+ """
1404
+ Tests various scenarios for the model.quantize() method, including
1405
+ error handling and valid calls.
1406
+ """
1407
+ model = _get_simple_model ()
1408
+
1409
+ if expected_exception :
1410
+ # Test for cases where an error is expected
1411
+ with pytest .raises (expected_exception , match = match_message ):
1412
+ model .quantize (mode , config = config )
1413
+ else :
1414
+ # Test for valid cases where no error should occur
1415
+ try :
1416
+ model .quantize (mode , config = config )
1417
+ except (ValueError , TypeError ) as e :
1418
+ pytest .fail (f"Test case '{ test_id } ' failed unexpectedly: { e } " )
0 commit comments