import os
import pickle
from collections import namedtuple
+ from collections.abc import Callable

import numpy as np
import pytest
@@ -1250,9 +1251,8 @@ def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000):
        yield rng.integers(low=0, high=vocab_size, size=(1, seqlen))


- # Helper function to build a simple transformer model that uses standard
- # Keras `Dense` layers for its attention projections.
- def _get_model_with_dense_attention():
+ # Helper function to build a simple transformer model.
+ def get_model_with_dense_attention():
    """Builds a simple transformer model using Dense for attention."""
    vocab_size = 1000
    embed_dim = 32
@@ -1262,8 +1262,6 @@ def _get_model_with_dense_attention():
    class SimpleTransformerBlock(layers.Layer):
        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
            super().__init__(**kwargs)
-             # The standard MultiHeadAttention layer uses Dense layers
-             # for its projections.
            self.att = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim
            )
@@ -1292,22 +1290,44 @@ def call(self, inputs):
    return model


+ # Define parameters for the tests
+ long_text = """gptq is an easy-to-use model quantization library..."""
+ DATASETS = {
+     "string_dataset": [long_text],
+     "generator_dataset": lambda: dummy_dataset_generator(
+         nsamples=16, seqlen=128
+     ),
+ }
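+ # Each CONFIGS entry is merged over the base GPTQ settings in
+ # _run_gptq_test_on_dataset via config_kwargs.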
+ CONFIGS = {
+     "default": {},
+     "per_channel": {"group_size": -1},
+     "act_order": {"act_order": True},
+     "symmetric": {"symmetric": True},
+ }
+
+
@pytest.mark.requires_trainable_backend
- class ModelQuantizationTest(testing.TestCase):
+ class TestModelQuantization:
    def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
        """Helper function to run a full GPTQ quantization test."""
-         model = _get_model_with_dense_attention()
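+         # Datasets may be passed directly or as zero-argument factories;
+         # call a factory here to materialize the dataset for this run.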
+         if isinstance(dataset, Callable):
+             dataset = dataset()
+         model = get_model_with_dense_attention()
        rng = np.random.default_rng(seed=42)

-         # 1. Common setup
        NUM_SAMPLES = 16
        SEQUENCE_LENGTH = 128
        VOCAB_SIZE = 1000
        W_BITS = 4

-         # Default config that can be overridden by config_kwargs
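+         # Mock tokenizer: maps each character to a vocab id. It is both
+         # callable and exposes .tokenize, covering either tokenizer API.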
+         mock_tokenizer = lambda text: np.array(
+             [ord(c) % VOCAB_SIZE for c in text]
+         )
+         mock_tokenizer.tokenize = mock_tokenizer
+
        base_config = {
            "dataset": dataset,
+             "tokenizer": mock_tokenizer,
            "wbits": W_BITS,
            "nsamples": NUM_SAMPLES,
            "seqlen": SEQUENCE_LENGTH,
@@ -1316,82 +1336,26 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
            "act_order": False,
        }

-         mock_tokenizer = lambda text: np.array(
-             [ord(c) % VOCAB_SIZE for c in text]
-         )
-         mock_tokenizer.tokenize = mock_tokenizer
-         base_config["tokenizer"] = mock_tokenizer
-
-         # Find target layer and get original weights
        target_layer = model.layers[2].ffn.layers[0]
-         self.assertIsNotNone(
-             target_layer,
-             "Test setup failed: No Dense layer found in 'ffn' block.",
-         )
+         assert target_layer is not None
        original_weights = np.copy(target_layer.kernel)

-         # Configure and run quantization
        final_config = {**base_config, **config_kwargs}
        gptq_config = GPTQConfig(**final_config)

        model.quantize("gptq", config=gptq_config)

-         # Assertions and verification
        quantized_weights = target_layer.kernel

-         self.assertNotAllClose(
-             original_weights,
-             quantized_weights,
-             msg=f"Weights not changed by GPTQ for config: {config_kwargs}",
-         )
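+         # GPTQ must actually modify the target layer's weights in place.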
+         assert not np.allclose(original_weights, quantized_weights)

        dummy_sample = rng.integers(
            low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
        )
        _ = model.predict(dummy_sample)

-     def test_quantize_gptq_on_different_datasets(self):
-         """Tests GPTQ with various dataset types (string list, generator)."""
-
-         # Define the datasets to be tested
-         long_text = """gptq is an easy-to-use model quantization library
-         with user-friendly apis, based on GPTQ algorithm. The goal is to
-         quantize pre-trained models to 4-bit or even 3-bit precision with
-         minimal performance degradation.
-         This allows for running larger models on less powerful hardware,
-         reducing memory footprint and increasing inference speed.
-         The process involves calibrating the model on a small dataset
-         to determine the quantization parameters.
-         This technique is particularly useful for deploying large language
-         models in resource-constrained environments where every bit of memory
-         and every millisecond of latency counts."""
-
-         datasets_to_test = {
-             "string_dataset": [long_text],
-             "generator_dataset": dummy_dataset_generator(
-                 nsamples=16, seqlen=128, vocab_size=1000
-             ),
-         }
-
-         # Loop through the datasets and run each as a sub-test
-         for dataset_name, dataset in datasets_to_test.items():
-             with self.subTest(dataset_type=dataset_name):
-                 self._run_gptq_test_on_dataset(dataset)
-
-     def test_quantize_gptq_with_config_variations(self):
-         """Tests GPTQ with specific config variations."""
-         config_variations = {
-             "per_channel": {"group_size": -1},
-             "act_order": {"act_order": True},
-             "symmetric": {"symmetric": True},
-             "all_options_enabled": {
-                 "group_size": -1,
-                 "act_order": True,
-                 "symmetric": True,
-             },
-         }
-
-         dataset = ["This is the calibration data for the test."]
-         for config_name, config_overrides in config_variations.items():
-             with self.subTest(config_type=config_name):
-                 self._run_gptq_test_on_dataset(dataset, **config_overrides)
+     @pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys())
+     @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
+     def test_quantize_gptq_combinations(self, dataset, config):
+         """Runs GPTQ tests across different datasets and config variations."""
+         self._run_gptq_test_on_dataset(dataset, **config)
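
Note: the stacked parametrize decorators expand to the cross product of the
two datasets and four configs, i.e. 8 generated test cases. As a sketch
(test-file path assumed, not taken from this diff), a single combination can
be selected by id with -k:

    pytest path/to/gptq_test.py -k "per_channel and generator_dataset"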