import os
import pickle
from collections import namedtuple
+from collections.abc import Callable

import numpy as np
import pytest

from keras.src import backend
from keras.src import layers
from keras.src import losses
+from keras.src import models
from keras.src import testing
from keras.src import tree
from keras.src.layers.core.input_layer import Input
from keras.src.models.functional import Functional
from keras.src.models.model import Model
from keras.src.models.model import model_from_json
+from keras.src.quantizers.gptq_config import GPTQConfig


def _get_model():
@@ -1237,3 +1240,177 @@ def test_export_error(self):
            ),
        ):
            model.export(temp_filepath, format="tf_saved_model")
+
+
+def dummy_dataset_generator(num_samples, sequence_length, vocab_size=1000):
+    """A generator that yields random numpy arrays for fast,
+    self-contained tests."""
+    rng = np.random.default_rng(seed=42)
+    for _ in range(num_samples):
+        yield rng.integers(low=0, high=vocab_size, size=(1, sequence_length))
+
+
+def get_model_with_dense_attention():
+    """Builds a simple transformer model using Dense for attention."""
+    vocab_size = 1000
+    embed_dim = 32
+    num_heads = 4
+    ff_dim = 32
+
+    class SimpleTransformerBlock(layers.Layer):
+        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
+            super().__init__(**kwargs)
+            self.att = layers.MultiHeadAttention(
+                num_heads=num_heads, key_dim=embed_dim
+            )
+            self.ffn = models.Sequential(
+                [
+                    layers.Dense(ff_dim, activation="relu"),
+                    layers.Dense(embed_dim),
+                ]
+            )
+            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+
+        def call(self, inputs):
+            attention_output = self.att(inputs, inputs)
+            out1 = self.layernorm1(inputs + attention_output)
+            ffn_output = self.ffn(out1)
+            return self.layernorm2(out1 + ffn_output)
+
+    inputs = layers.Input(shape=(None,), dtype="int32")
+    embedding_layer = layers.Embedding(vocab_size, embed_dim)
+    x = embedding_layer(inputs)
+    transformer_block = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)
+    x = transformer_block(x)
+    outputs = layers.Dense(vocab_size)(x)
+    model = models.Model(inputs=inputs, outputs=outputs)
+    return model
+
+
+# Define parameters for the tests
+long_text = """gptq is an easy-to-use model quantization library..."""
+DATASETS = {
+    "string_dataset": [long_text],
+    "generator_dataset": lambda: dummy_dataset_generator(
+        num_samples=16, sequence_length=128
+    ),
+}
+CONFIGS = {
+    "default": {},
+    "per_channel": {"group_size": -1},
+    "act_order": {"activation_order": True},
+    "symmetric": {"symmetric": True},
+}
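+# NOTE: each entry overrides the base GPTQ settings built in
+# `_run_gptq_test_on_dataset`; `group_size=-1` is assumed to request
+# per-channel quantization and `activation_order=True` to reorder
+# columns by activation magnitude before quantizing.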
+
+
+def _get_simple_model():
+    """Builds a simple sequential model for testing."""
+    return models.Sequential([layers.Dense(10, input_shape=(5,))])
+
+
+quantize_test_cases = [
+    # --- Error Scenarios ---
+    (
+        "gptq",
+        {"weight_bits": 4},  # Invalid config (dict, not GPTQConfig)
+        ValueError,
+        "The `config` argument must be of type",
+        "gptq_with_invalid_config",
+    ),
+    (
+        "int8",
+        GPTQConfig(dataset=["test"], tokenizer=lambda x: x),
+        ValueError,
+        "is only supported for 'gptq' mode",
+        "non_gptq_with_unsupported_config",
+    ),
+    # --- Valid Scenario ---
+    (
+        "int8",
+        None,  # No config, which is correct
+        None,  # No exception expected
+        None,
+        "non_gptq_runs_without_error",
+    ),
+]
+
+
+@pytest.mark.requires_trainable_backend
+class TestModelQuantization:
+    def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
+        """Helper function to run a full GPTQ quantization test."""
+        if isinstance(dataset, Callable):
+            dataset = dataset()
+        model = get_model_with_dense_attention()
+        rng = np.random.default_rng(seed=42)
+
+        NUM_SAMPLES = 16
+        SEQUENCE_LENGTH = 128
+        VOCAB_SIZE = 1000
+        W_BITS = 4
+
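+        # Character-level mock tokenizer: it is callable and also exposes
+        # a `.tokenize` attribute, mimicking the duck-typed tokenizer
+        # interface that GPTQConfig presumably accepts.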
+        mock_tokenizer = lambda text: np.array(
+            [ord(c) % VOCAB_SIZE for c in text]
+        )
+        mock_tokenizer.tokenize = mock_tokenizer
+
+        base_config = {
+            "dataset": dataset,
+            "tokenizer": mock_tokenizer,
+            "weight_bits": W_BITS,
+            "num_samples": NUM_SAMPLES,
+            "sequence_length": SEQUENCE_LENGTH,
+            "group_size": 32,
+            "symmetric": False,
+            "activation_order": False,
+        }
+
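+        # model.layers is [InputLayer, Embedding, SimpleTransformerBlock,
+        # Dense], so index 2 is the transformer block and ffn.layers[0]
+        # is the first Dense layer of its feed-forward sub-network.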
+        target_layer = model.layers[2].ffn.layers[0]
+        assert target_layer is not None
+        original_weights = np.copy(target_layer.kernel)
+
+        final_config = {**base_config, **config_kwargs}
+        gptq_config = GPTQConfig(**final_config)
+
+        model.quantize("gptq", config=gptq_config)
+
+        quantized_weights = target_layer.kernel
+
+        assert not np.allclose(original_weights, quantized_weights)
+
+        dummy_sample = rng.integers(
+            low=0, high=VOCAB_SIZE, size=(1, SEQUENCE_LENGTH)
+        )
+        _ = model.predict(dummy_sample)
+
+    @pytest.mark.parametrize("dataset", DATASETS.values(), ids=DATASETS.keys())
+    @pytest.mark.parametrize("config", CONFIGS.values(), ids=CONFIGS.keys())
+    def test_quantize_gptq_combinations(self, dataset, config):
+        """Runs GPTQ tests across different datasets and config variations."""
+        self._run_gptq_test_on_dataset(dataset, **config)
+
+    @pytest.mark.parametrize(
+        "mode, config, expected_exception, match_message, test_id",
+        quantize_test_cases,
+        ids=[case[-1] for case in quantize_test_cases],
+    )
+    def test_quantize_scenarios(
+        self, mode, config, expected_exception, match_message, test_id
+    ):
+        """
+        Tests various scenarios for the model.quantize() method, including
+        error handling and valid calls.
+        """
+        model = _get_simple_model()
+
+        if expected_exception:
+            # Test for cases where an error is expected
+            with pytest.raises(expected_exception, match=match_message):
+                model.quantize(mode, config=config)
+        else:
+            # Test for valid cases where no error should occur
+            try:
+                model.quantize(mode, config=config)
+            except (ValueError, TypeError) as e:
+                pytest.fail(f"Test case '{test_id}' failed unexpectedly: {e}")