|
| 1 | +import pytest |
| 2 | + |
| 3 | +from keras.src import layers |
| 4 | +from keras.src import models |
| 5 | +from keras.src.quantizers import gptq_core |
| 6 | +from keras.src.quantizers.gptq_config import GPTQConfig |
| 7 | + |
| 8 | + |
def _get_model_no_embedding():
    """Build a minimal model whose only layer is Dense, i.e. no Embedding."""
    dense = layers.Dense(10, input_shape=(5,))
    return models.Sequential([dense])
| 12 | + |
| 13 | + |
def _get_model_no_blocks():
    """Build a model that has an Embedding layer but no transformer-like
    container layers after it."""
    embedding = layers.Embedding(100, 10, input_shape=(5,))
    return models.Sequential([embedding])
| 18 | + |
| 19 | + |
class MockTokenizer:
    """Minimal tokenizer stand-in: maps every character of the (joined)
    input text to its Unicode code point.

    Mirrors the real tokenizer API used by the quantizer: both
    ``tokenize(text)`` and direct ``__call__(text)`` are supported.
    """

    def tokenize(self, text):
        # Joining first lets callers pass either a string or a list of
        # string fragments; ord() then yields one id per character.
        joined = "".join(text)
        return list(map(ord, joined))

    def __call__(self, text):
        # Calling the tokenizer is an alias for tokenize().
        return self.tokenize(text)
| 28 | + |
| 29 | + |
# Parametrization cases for unsupported-architecture tests.
# Each tuple is (model, expected ValueError message fragment, test id).
# NOTE(review): the models are constructed at module import time and are
# therefore shared across parametrized runs — presumably fine because each
# quantize() call is expected to fail before mutating the model; confirm.
architecture_test_cases = [
    (
        _get_model_no_embedding(),
        "Could not automatically find an embedding layer",
        "no_embedding_layer",
    ),
    (
        _get_model_no_blocks(),
        "Could not automatically find any transformer-like blocks",
        "no_transformer_blocks",
    ),
]
| 42 | + |
| 43 | + |
@pytest.mark.requires_trainable_backend
class TestGPTQCore:
    """Error-path tests for the GPTQ quantization entry points."""

    def test_get_dataloader_with_empty_dataset(self):
        """An empty calibration dataset must be rejected with ValueError."""
        with pytest.raises(ValueError, match="Provided dataset is empty"):
            gptq_core.get_dataloader(
                tokenizer=MockTokenizer(), seqlen=10, dataset=[], nsamples=10
            )

    @pytest.mark.parametrize(
        "model, match_message, test_id",
        architecture_test_cases,
        ids=[case[2] for case in architecture_test_cases],
    )
    def test_apply_gptq_with_unsupported_architectures(
        self, model, match_message, test_id
    ):
        """Quantizing a model without the expected embedding/transformer
        structure must raise a ValueError with a descriptive message."""
        config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer())

        with pytest.raises(ValueError, match=match_message):
            model.quantize("gptq", config=config)