Skip to content

Commit 176eb63

Browse files
Added test case to improve the coverage
1 parent 80e5cf5 commit 176eb63

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
3+
from keras.src import layers
4+
from keras.src import models
5+
from keras.src.quantizers import gptq_core
6+
from keras.src.quantizers.gptq_config import GPTQConfig
7+
8+
9+
def _get_model_no_embedding():
10+
"""Returns a simple model that lacks an Embedding layer."""
11+
return models.Sequential([layers.Dense(10, input_shape=(5,))])
12+
13+
14+
def _get_model_no_blocks():
15+
"""Returns a model with an embedding layer but no subsequent container
16+
layers."""
17+
return models.Sequential([layers.Embedding(100, 10, input_shape=(5,))])
18+
19+
20+
class MockTokenizer:
21+
"""A mock tokenizer that mimics the real API for testing."""
22+
23+
def tokenize(self, text):
24+
return [ord(c) for c in "".join(text)]
25+
26+
def __call__(self, text):
27+
return self.tokenize(text)
28+
29+
30+
architecture_test_cases = [
31+
(
32+
_get_model_no_embedding(),
33+
"Could not automatically find an embedding layer",
34+
"no_embedding_layer",
35+
),
36+
(
37+
_get_model_no_blocks(),
38+
"Could not automatically find any transformer-like blocks",
39+
"no_transformer_blocks",
40+
),
41+
]
42+
43+
44+
@pytest.mark.requires_trainable_backend
45+
class TestGPTQCore:
46+
def test_get_dataloader_with_empty_dataset(self):
47+
"""
48+
Tests that get_dataloader raises a ValueError for an empty dataset.
49+
"""
50+
with pytest.raises(ValueError, match="Provided dataset is empty"):
51+
gptq_core.get_dataloader(
52+
tokenizer=MockTokenizer(), seqlen=10, dataset=[], nsamples=10
53+
)
54+
55+
@pytest.mark.parametrize(
56+
"model, match_message, test_id",
57+
architecture_test_cases,
58+
ids=[case[-1] for case in architecture_test_cases],
59+
)
60+
def test_apply_gptq_with_unsupported_architectures(
61+
self, model, match_message, test_id
62+
):
63+
"""
64+
Tests that quantize fails correctly for various unsupported model
65+
architectures.
66+
"""
67+
config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer())
68+
69+
with pytest.raises(ValueError, match=match_message):
70+
model.quantize("gptq", config=config)

0 commit comments

Comments
 (0)