Commit 37370e0

Added test cases to improve coverage
1 parent 176eb63 commit 37370e0

File tree: 1 file changed, 122 additions & 32 deletions
@@ -1,56 +1,146 @@
 import pytest
+from absl import logging
 
 from keras.src import layers
 from keras.src import models
 from keras.src.quantizers import gptq_core
 from keras.src.quantizers.gptq_config import GPTQConfig
 
-
-def _get_model_no_embedding():
-    """Returns a simple model that lacks an Embedding layer."""
-    return models.Sequential([layers.Dense(10, input_shape=(5,))])
-
-
-def _get_model_no_blocks():
-    """Returns a model with an embedding layer but no subsequent container
-    layers."""
-    return models.Sequential([layers.Embedding(100, 10, input_shape=(5,))])
+VOCAB_SIZE = 100
 
 
 class MockTokenizer:
     """A mock tokenizer that mimics the real API for testing."""
 
     def tokenize(self, text):
-        return [ord(c) for c in "".join(text)]
+        return [ord(c) % VOCAB_SIZE for c in "".join(text)]
 
     def __call__(self, text):
         return self.tokenize(text)
 
 
-architecture_test_cases = [
-    (
-        _get_model_no_embedding(),
-        "Could not automatically find an embedding layer",
-        "no_embedding_layer",
-    ),
-    (
-        _get_model_no_blocks(),
-        "Could not automatically find any transformer-like blocks",
-        "no_transformer_blocks",
-    ),
-]
+class MockEmptyBlock(layers.Layer):
+    """A mock block that contains no quantizable layers."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.ln = layers.LayerNormalization()
+
+    def call(self, inputs):
+        return self.ln(inputs)
+
+
+class MockTransformerBlock(layers.Layer):
+    """A mock transformer block with a quantizable Dense layer."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.dense = layers.Dense(128)
+
+    def call(self, inputs):
+        return self.dense(inputs)
+
+
+def _get_model_with_backbone(
+    has_transformer_layers=True, embedding_name="embedding"
+):
+    """Creates a mock KerasNLP-style model with a backbone."""
+
+    class MockBackbone(layers.Layer):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            if has_transformer_layers:
+                self.transformer_layers = [MockTransformerBlock()]
+            setattr(self, embedding_name, layers.Embedding(VOCAB_SIZE, 128))
+
+    class MockModel(models.Model):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            self.backbone = MockBackbone()
+
+        def call(self, inputs):
+            return self.backbone(inputs)
+
+    model = MockModel()
+    model.build(input_shape=(None, 10))
+    return model
 
 
 @pytest.mark.requires_trainable_backend
 class TestGPTQCore:
-    def test_get_dataloader_with_empty_dataset(self):
-        """
-        Tests that get_dataloader raises a ValueError for an empty dataset.
-        """
+    def test_get_dataloader_error_scenarios(self):
+        """Tests error cases for get_dataloader."""
         with pytest.raises(ValueError, match="Provided dataset is empty"):
             gptq_core.get_dataloader(
                 tokenizer=MockTokenizer(), seqlen=10, dataset=[], nsamples=10
             )
+        with pytest.raises(
+            TypeError,
+            match="Providing a dataset name as a "
+            "string is not supported. Please pass the "
+            "loaded dataset directly.",
+        ):
+            gptq_core.get_dataloader(
+                tokenizer=MockTokenizer(),
+                seqlen=10,
+                dataset="wikitext2",
+                nsamples=10,
+            )
+
+    def test_apply_gptq_on_multi_block_model(self):
+        """Tests quantization on a model with multiple blocks."""
+        model = models.Sequential(
+            [
+                layers.Embedding(VOCAB_SIZE, 128),
+                MockTransformerBlock(),
+                MockTransformerBlock(),
+            ]
+        )
+        model.build(input_shape=(None, 10))
+        config = GPTQConfig(
+            dataset=["test data"], tokenizer=MockTokenizer(), group_size=32
+        )
+        try:
+            model.quantize("gptq", config=config)
+        except Exception as e:
+            pytest.fail(f"Multi-block quantization failed unexpectedly: {e}")
+
+    def test_apply_gptq_with_empty_block(self, caplog):
+        """Tests that a block with no quantizable layers is skipped
+        correctly."""
+        caplog.set_level(logging.INFO)
+        model = models.Sequential(
+            [layers.Embedding(VOCAB_SIZE, 10), MockEmptyBlock()]
+        )
+        model.build(input_shape=(None, 10))
+        config = GPTQConfig(dataset=["test data"], tokenizer=MockTokenizer())
+        model.quantize("gptq", config=config)
+        assert "No Dense or EinsumDense layers found" in caplog.text
+
+    architecture_test_cases = [
+        (
+            models.Sequential([layers.Dense(10)]),
+            "Could not automatically find an embedding layer",
+            "no_embedding_layer",
+        ),
+        (
+            models.Sequential(
+                [layers.Embedding(VOCAB_SIZE, 10), layers.Dense(10)]
+            ),
+            "Could not automatically find any transformer-like blocks",
+            "no_transformer_blocks",
+        ),
+        (
+            _get_model_with_backbone(has_transformer_layers=False),
+            "backbone does not have a 'transformer_layers' attribute",
+            "backbone_no_layers",
+        ),
+        (
+            _get_model_with_backbone(embedding_name="wrong_name"),
+            "Could not automatically find an embedding layer in the model",
+            "backbone_no_embedding",
+        ),
+    ]
 
     @pytest.mark.parametrize(
         "model, match_message, test_id",
@@ -60,11 +150,11 @@ def test_get_dataloader_with_empty_dataset(self):
     def test_apply_gptq_with_unsupported_architectures(
         self, model, match_message, test_id
     ):
-        """
-        Tests that quantize fails correctly for various unsupported model
-        architectures.
-        """
-        config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer())
+        """Tests that quantize fails correctly for various unsupported
+        model architectures."""
+        if not model.built:
+            model.build(input_shape=(None, 10))
 
+        config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer())
         with pytest.raises(ValueError, match=match_message):
             model.quantize("gptq", config=config)
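
One detail worth spelling out: MockTokenizer.tokenize now folds ids with % VOCAB_SIZE because the mock models embed tokens with layers.Embedding(VOCAB_SIZE, ...), and a raw ord(c) id (e.g. ord("t") == 116) would index past the 100-row embedding table. A minimal standalone sketch of that invariant, in plain Python and not part of the commit:

    VOCAB_SIZE = 100  # matches the constant introduced in this commit

    def tokenize(text):
        # Same scheme as the test's MockTokenizer: one id per character,
        # folded into the embedding table's valid index range [0, VOCAB_SIZE).
        return [ord(c) % VOCAB_SIZE for c in "".join(text)]

    ids = tokenize(["test data"])
    assert all(0 <= i < VOCAB_SIZE for i in ids)  # safe for Embedding lookups
    assert tokenize(["t"]) == [16]                # ord("t") == 116 -> 16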
