|
| 1 | +import io |
| 2 | +import logging |
1 | 3 | import os
|
2 | 4 | import pickle
|
| 5 | +import tarfile |
3 | 6 | from collections import namedtuple
|
4 | 7 |
|
5 | 8 | import numpy as np
|
6 | 9 | import pytest
|
| 10 | +import requests |
7 | 11 | from absl.testing import parameterized
|
| 12 | +from datasets import load_dataset |
8 | 13 |
|
9 | 14 | from keras.src import backend
|
10 | 15 | from keras.src import layers
|
11 | 16 | from keras.src import losses
|
| 17 | +from keras.src import models |
12 | 18 | from keras.src import testing
|
13 | 19 | from keras.src import tree
|
14 | 20 | from keras.src.layers.core.input_layer import Input
|
15 | 21 | from keras.src.models.functional import Functional
|
16 | 22 | from keras.src.models.model import Model
|
17 | 23 | from keras.src.models.model import model_from_json
|
| 24 | +from keras.src.quantizers.gptqconfig import GPTQConfig |
| 25 | + |
| 26 | +# Configure logging |
| 27 | +logging.basicConfig(level=logging.INFO) |
| 28 | + |
| 29 | + |
| 30 | +def get_dataset_text(dataset_identifier: str, nsamples=1000) -> str: |
| 31 | + """ |
| 32 | + Loads a specified dataset and extracts its text content into a |
| 33 | + single string. |
| 34 | + """ |
| 35 | + DATASET_CONFIGS = { |
| 36 | + "wikitext2": { |
| 37 | + "name": "wikitext", |
| 38 | + "config": "wikitext-2-raw-v1", |
| 39 | + "split": "test", |
| 40 | + "text_column": "text", |
| 41 | + }, |
| 42 | + "ptb": { |
| 43 | + "name": "ptb_text_only", |
| 44 | + "config": "penn_treebank", |
| 45 | + "split": "validation", |
| 46 | + "text_column": "sentence", |
| 47 | + }, |
| 48 | + "c4": { |
| 49 | + "name": "allenai/c4", |
| 50 | + "config": "en", |
| 51 | + "split": "validation", # Use validation for C4's test split |
| 52 | + "text_column": "text", |
| 53 | + }, |
| 54 | + } |
| 55 | + |
| 56 | + if dataset_identifier not in DATASET_CONFIGS: |
| 57 | + raise ValueError( |
| 58 | + f"Unknown dataset identifier '{dataset_identifier}'. " |
| 59 | + f"Available options are: {list(DATASET_CONFIGS.keys())}" |
| 60 | + ) |
| 61 | + |
| 62 | + config = DATASET_CONFIGS[dataset_identifier] |
| 63 | + |
| 64 | + if dataset_identifier == "ptb": |
| 65 | + url = "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz" |
| 66 | + try: |
| 67 | + # 1. Download the archive into memory |
| 68 | + response = requests.get(url) |
| 69 | + response.raise_for_status() |
| 70 | + |
| 71 | + # 2. Extract only the test file from the in-memory archive |
| 72 | + with tarfile.open( |
| 73 | + fileobj=io.BytesIO(response.content), mode="r:gz" |
| 74 | + ) as tar: |
| 75 | + test_path = "./simple-examples/data/ptb.test.txt" |
| 76 | + test_bytes = tar.extractfile(test_path).read() |
| 77 | + |
| 78 | + # 3. Decode the bytes and join into a single string |
| 79 | + test_lines = test_bytes.decode("utf-8").strip().split("\n") |
| 80 | + all_text = "\n\n".join(test_lines) |
| 81 | + |
| 82 | + print("✅ Successfully processed PTB test data.") |
| 83 | + return all_text |
| 84 | + |
| 85 | + except Exception as e: |
| 86 | + print(f"Failed to download or process PTB data: {e!r}") |
| 87 | + raise |
| 88 | + |
| 89 | + load_kwargs = {"name": config["config"]} |
| 90 | + |
| 91 | + if dataset_identifier == "c4": |
| 92 | + load_kwargs["streaming"] = True |
| 96 | + |
| 97 | + print(f"Loading dataset '{config['name']}'...") |
| 98 | + |
| 99 | + test_data = load_dataset( |
| 100 | + config["name"], split=config["split"], **load_kwargs |
| 101 | + ) |
| 102 | + |
| 103 | + if dataset_identifier == "c4": |
| 104 | + print(f" -> Limiting C4 to the first {nsamples} documents forspeed.") |
| 105 | + test_data = test_data.take(nsamples) |
| 106 | + |
| 107 | + all_text = "\n\n".join( |
| 108 | + row[config["text_column"]] |
| 109 | + for row in test_data |
| 110 | + if row.get(config["text_column"]) |
| 111 | + ) |
| 112 | + |
| 113 | + print(f"Successfully loaded and processed {dataset_identifier}.") |
| 114 | + return all_text |
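For reference, a minimal usage sketch of the helper above (illustrative only, not part of the test suite); the identifiers are the keys of DATASET_CONFIGS, and the nsamples cap is only applied to the streamed C4 split:

    # Hypothetical calls, assuming network access is available:
    # calibration_text = get_dataset_text("wikitext2")
    # c4_text = get_dataset_text("c4", nsamples=256)  # keeps the first 256 documents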
18 | 115 |
|
19 | 116 |
|
20 | 117 | def _get_model():
|
@@ -1237,3 +1334,156 @@ def test_export_error(self):
|
1237 | 1334 | ),
|
1238 | 1335 | ):
|
1239 | 1336 | model.export(temp_filepath, format="tf_saved_model")
|
| 1337 | + |
| 1338 | + |
| 1339 | +# Helper function to generate dummy data for quick testing. |
| 1340 | +def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000): |
| 1341 | + """A generator that yields random numpy arrays for fast, |
| 1342 | + self-contained tests.""" |
| 1343 | + for _ in range(nsamples): |
| 1344 | + yield np.random.randint(0, vocab_size, size=(1, seqlen)) |
| 1345 | + |
| 1346 | + |
| 1347 | +# Helper function to build a simple transformer model that uses standard |
| 1348 | +# Keras `Dense` layers for its attention projections. |
| 1349 | +def _get_model_with_dense_attention(): |
| 1350 | + """Builds a simple transformer model using Dense for attention.""" |
| 1351 | + vocab_size = 1000 |
| 1352 | + embed_dim = 32 |
| 1353 | + num_heads = 4 |
| 1354 | + ff_dim = 32 |
| 1355 | + |
| 1356 | + class SimpleTransformerBlock(layers.Layer): |
| 1357 | + def __init__(self, embed_dim, num_heads, ff_dim, **kwargs): |
| 1358 | + super().__init__(**kwargs) |
| 1359 | + # The standard MultiHeadAttention layer uses Dense layers |
| 1360 | + # for its projections. |
| 1361 | + self.att = layers.MultiHeadAttention( |
| 1362 | + num_heads=num_heads, key_dim=embed_dim |
| 1363 | + ) |
| 1364 | + self.ffn = models.Sequential( |
| 1365 | + [ |
| 1366 | + layers.Dense(ff_dim, activation="relu"), |
| 1367 | + layers.Dense(embed_dim), |
| 1368 | + ] |
| 1369 | + ) |
| 1370 | + self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) |
| 1371 | + self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) |
| 1372 | + |
| 1373 | + def call(self, inputs): |
| 1374 | + attention_output = self.att(inputs, inputs) |
| 1375 | + out1 = self.layernorm1(inputs + attention_output) |
| 1376 | + ffn_output = self.ffn(out1) |
| 1377 | + return self.layernorm2(out1 + ffn_output) |
| 1378 | + |
| 1379 | + inputs = layers.Input(shape=(None,), dtype="int32") |
| 1380 | + embedding_layer = layers.Embedding(vocab_size, embed_dim) |
| 1381 | + x = embedding_layer(inputs) |
| 1382 | + transformer_block = SimpleTransformerBlock(embed_dim, num_heads, ff_dim) |
| 1383 | + x = transformer_block(x) |
| 1384 | + outputs = layers.Dense(vocab_size)(x) |
| 1385 | + model = models.Model(inputs=inputs, outputs=outputs) |
| 1386 | + return model |
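For orientation, a hedged sanity sketch of the I/O contract the tests below rely on: integer token ids go in, per-token logits over the 1000-token vocabulary come out (illustrative only):

    # Sketch only:
    # m = _get_model_with_dense_attention()
    # y = m.predict(np.zeros((1, 128), dtype="int32"))
    # assert y.shape == (1, 128, 1000)  # (batch, seqlen, vocab_size)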
| 1387 | + |
| 1388 | + |
| 1389 | +def _run_gptq_test_on_dataset(test_case, dataset): |
| 1390 | + """Helper function to run a full GPTQ quantization |
| 1391 | + test on a given dataset.""" |
| 1392 | + model = _get_model_with_dense_attention() |
| 1393 | + |
| 1394 | + # --- 1. Common Setup --- |
| 1395 | + NUM_SAMPLES = 16 |
| 1396 | + SEQUENCE_LENGTH = 128 |
| 1397 | + VOCAB_SIZE = 1000 |
| 1398 | + W_BITS = 4 |
| 1399 | + GROUP_SIZE = 32 |
| 1400 | + |
| 1401 | + mock_tokenizer = lambda text: np.array([ord(c) % VOCAB_SIZE for c in text]) |
| 1402 | + mock_tokenizer.tokenize = mock_tokenizer |
| 1403 | + |
| 1404 | + # --- 2. Find Target Layer and Get Original Weights --- |
| 1405 | + target_layer = None |
| 1406 | + for layer in model.layers: |
| 1407 | + if hasattr(layer, "ffn") and hasattr(layer.ffn, "layers"): |
| 1408 | + dense_layer_in_ffn = next( |
| 1409 | + ( |
| 1410 | + ffn_layer |
| 1411 | + for ffn_layer in layer.ffn.layers |
| 1412 | + if isinstance(ffn_layer, layers.Dense) |
| 1413 | + ), |
| 1414 | + None, |
| 1415 | + ) |
| 1416 | + if dense_layer_in_ffn: |
| 1417 | + target_layer = dense_layer_in_ffn |
| 1418 | + break |
| 1419 | + |
| 1420 | + test_case.assertIsNotNone( |
| 1421 | + target_layer, |
| 1422 | + "Test setup failed: No Dense layer was found inside an 'ffn' block.", |
| 1423 | + ) |
| 1424 | + original_weights = np.copy(target_layer.kernel.numpy()) |
| 1425 | + |
| 1426 | + # --- 3. Configure and Run Quantization --- |
| 1427 | + gptq_config = GPTQConfig( |
| 1428 | + dataset=dataset, |
| 1429 | + tokenizer=mock_tokenizer, |
| 1430 | + wbits=W_BITS, |
| 1431 | + nsamples=NUM_SAMPLES, |
| 1432 | + seqlen=SEQUENCE_LENGTH, |
| 1433 | + groupsize=GROUP_SIZE, |
| 1434 | + ) |
| 1435 | + model.quantize("gptq", quant_config=gptq_config) |
| 1436 | + |
| 1437 | + # --- 4. Assertions and Verification --- |
| 1438 | + quantized_weights = target_layer.kernel.numpy() |
| 1439 | + |
| 1440 | + # Assert that the weights have been changed |
| 1441 | + test_case.assertFalse( |
| 1442 | + np.allclose(original_weights, quantized_weights), |
| 1443 | + f"Weights were not changed by the GPTQ process for dataset: {dataset}", |
| 1444 | + ) |
| 1445 | + |
| 1446 | + # Verify the quantized model can still make a prediction |
| 1447 | + try: |
| 1448 | + dummy_input = np.random.randint( |
| 1449 | + 0, VOCAB_SIZE, size=(1, SEQUENCE_LENGTH) |
| 1450 | + ) |
| 1451 | + _ = model.predict(dummy_input) |
| 1452 | + except Exception as e: |
| 1453 | + test_case.fail( |
| 1454 | + "Prediction failed for the quantized model with dataset: " |
| 1455 | + f"{dataset}. Error: {e}" |
| 1456 | + ) |
| 1457 | + |
| 1458 | + |
| 1459 | +@pytest.mark.requires_trainable_backend |
| 1460 | +class ModelQuantizationTest(testing.TestCase): |
| 1461 | + def test_quantize_gptq_with_dense_attention(self): |
| 1462 | + """Tests GPTQ with an in-memory list of strings as the dataset.""" |
| 1463 | + |
| 1464 | + long_text = """auto-gptq is an easy-to-use model quantization library |
| 1465 | + with user-friendly apis, based on GPTQ algorithm. The goal is to |
| 1466 | + quantize pre-trained models to 4-bit or even 3-bit precision with |
| 1467 | + minimal performance degradation. |
| 1468 | + This allows for running larger models on less powerful hardware, |
| 1469 | + reducing memory footprint and increasing inference speed. |
| 1470 | + The process involves calibrating the model on a small dataset |
| 1471 | + to determine the quantization parameters. |
| 1472 | + This technique is particularly useful for deploying large language |
| 1473 | + models in resource-constrained environments where every bit of memory |
| 1474 | + and every millisecond of latency counts.""" |
| 1475 | + |
| 1476 | + string_dataset = [long_text] |
| 1477 | + _run_gptq_test_on_dataset(self, string_dataset) |
| 1478 | + |
| 1479 | + def test_quantize_gptq_with_data_gen(self): |
| 1480 | + """Tests GPTQ with a Python generator as the dataset.""" |
| 1481 | + generator_dataset = dummy_dataset_generator( |
| 1482 | + nsamples=16, seqlen=128, vocab_size=1000 |
| 1483 | + ) |
| 1484 | + _run_gptq_test_on_dataset(self, generator_dataset) |
| 1485 | + |
| 1486 | + @pytest.mark.slow |
| 1487 | + def test_quantize_gptq_with_wikitext2(self): |
| 1488 | + """Tests GPTQ with the 'wikitext2' dataset identifier.""" |
| 1489 | + _run_gptq_test_on_dataset(self, "wikitext2") |
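The helper at the top of the file also recognizes the "ptb" and "c4" identifiers, so analogous slow tests are a natural follow-up; this is a sketch only, assuming GPTQConfig resolves those identifiers the same way it resolves "wikitext2":

    # Hypothetical companions to the wikitext2 test (both need network access):
    # @pytest.mark.slow
    # def test_quantize_gptq_with_ptb(self):
    #     _run_gptq_test_on_dataset(self, "ptb")
    #
    # @pytest.mark.slow
    # def test_quantize_gptq_with_c4(self):
    #     _run_gptq_test_on_dataset(self, "c4")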