@@ -1293,57 +1293,57 @@ def call(self, inputs):
1293
1293
1294
1294
@pytest .mark .requires_trainable_backend
1295
1295
class ModelQuantizationTest (testing .TestCase ):
1296
- def _run_gptq_test_on_dataset (self , dataset ):
1297
- """Helper function to run a full GPTQ quantization
1298
- test on a given dataset."""
1299
-
1296
+ def _run_gptq_test_on_dataset (self , dataset , ** config_kwargs ):
1297
+ """Helper function to run a full GPTQ quantization test."""
1300
1298
model = _get_model_with_dense_attention ()
1301
1299
rng = np .random .default_rng (seed = 42 )
1300
+
1302
1301
# 1. Common setup
1303
1302
NUM_SAMPLES = 16
1304
1303
SEQUENCE_LENGTH = 128
1305
1304
VOCAB_SIZE = 1000
1306
1305
W_BITS = 4
1307
- GROUP_SIZE = 32
1306
+
1307
+ # Default config that can be overridden by config_kwargs
1308
+ base_config = {
1309
+ "dataset" : dataset ,
1310
+ "wbits" : W_BITS ,
1311
+ "nsamples" : NUM_SAMPLES ,
1312
+ "seqlen" : SEQUENCE_LENGTH ,
1313
+ "group_size" : 32 ,
1314
+ "symmetric" : False ,
1315
+ "act_order" : False ,
1316
+ }
1308
1317
1309
1318
mock_tokenizer = lambda text : np .array (
1310
1319
[ord (c ) % VOCAB_SIZE for c in text ]
1311
1320
)
1312
1321
mock_tokenizer .tokenize = mock_tokenizer
1322
+ base_config ["tokenizer" ] = mock_tokenizer
1313
1323
1314
- # 2. Find target layer and get original weights
1324
+ # Find target layer and get original weights
1315
1325
target_layer = model .layers [2 ].ffn .layers [0 ]
1316
-
1317
1326
self .assertIsNotNone (
1318
1327
target_layer ,
1319
- "Test setup failed: No Dense layer was found inside "
1320
- "an 'ffn' block." ,
1328
+ "Test setup failed: No Dense layer found in 'ffn' block." ,
1321
1329
)
1322
1330
original_weights = np .copy (target_layer .kernel .numpy ())
1323
1331
1324
- # 3. Configure and run quantization
1325
- gptq_config = GPTQConfig (
1326
- dataset = dataset ,
1327
- tokenizer = mock_tokenizer ,
1328
- wbits = W_BITS ,
1329
- nsamples = NUM_SAMPLES ,
1330
- seqlen = SEQUENCE_LENGTH ,
1331
- group_size = GROUP_SIZE ,
1332
- )
1332
+ # Configure and run quantization
1333
+ final_config = {** base_config , ** config_kwargs }
1334
+ gptq_config = GPTQConfig (** final_config )
1335
+
1333
1336
model .quantize ("gptq" , quant_config = gptq_config )
1334
1337
1335
- # 4. Assertions and verification
1338
+ # Assertions and verification
1336
1339
quantized_weights = target_layer .kernel .numpy ()
1337
1340
1338
- # Assert that the weights have been changed
1339
1341
self .assertNotAllClose (
1340
1342
original_weights ,
1341
1343
quantized_weights ,
1342
- msg = "Weights were not changed by the GPTQ process for "
1343
- "dataset: {dataset}" ,
1344
+ msg = f"Weights not changed by GPTQ for config: { config_kwargs } " ,
1344
1345
)
1345
1346
1346
- # Verify the quantized model can still make a prediction
1347
1347
dummy_sample = rng .integers (
1348
1348
low = 0 , high = VOCAB_SIZE , size = (1 , SEQUENCE_LENGTH )
1349
1349
)
@@ -1378,3 +1378,21 @@ def test_quantize_gptq_on_different_datasets(self):
1378
1378
# for each specific dataset without stopping the whole test.
1379
1379
with self .subTest (dataset_type = dataset_name ):
1380
1380
self ._run_gptq_test_on_dataset (dataset )
1381
+
1382
def test_quantize_gptq_with_config_variations(self):
    """Run the GPTQ helper once for each named set of config overrides.

    Every variation is executed inside its own ``subTest`` context,
    labelled with the variation name through ``config_type``, and the
    overrides are forwarded to ``_run_gptq_test_on_dataset`` as keyword
    arguments alongside a single fixed calibration sentence.
    """
    calibration_texts = ["This is the calibration data for the test."]

    # (label, overrides) pairs, listed in the order the cases are run.
    cases = (
        ("per_channel", dict(group_size=-1)),
        ("act_order", dict(act_order=True)),
        ("symmetric", dict(symmetric=True)),
        (
            "all_options_enabled",
            dict(group_size=-1, act_order=True, symmetric=True),
        ),
    )

    for label, overrides in cases:
        with self.subTest(config_type=label):
            self._run_gptq_test_on_dataset(calibration_texts, **overrides)
0 commit comments