Skip to content

Commit f83ae0b

Browse files
Added test case to improve the coverage
1 parent 2ccbf61 commit f83ae0b

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

keras/src/models/model_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,38 @@ def call(self, inputs):
13061306
}
13071307

13081308

1309+
def _get_simple_model():
1310+
"""Builds a simple sequential model for testing."""
1311+
return models.Sequential([layers.Dense(10, input_shape=(5,))])
1312+
1313+
1314+
quantize_test_cases = [
1315+
# --- Error Scenarios ---
1316+
(
1317+
"gptq",
1318+
{"wbits": 4}, # Invalid config (dict, not GPTQConfig)
1319+
TypeError,
1320+
"must pass a `config` argument of type",
1321+
"gptq_with_invalid_config",
1322+
),
1323+
(
1324+
"int8",
1325+
GPTQConfig(dataset=["test"], tokenizer=lambda x: x),
1326+
ValueError,
1327+
"is only supported for 'gptq' mode",
1328+
"non_gptq_with_unsupported_config",
1329+
),
1330+
# --- Valid Scenario ---
1331+
(
1332+
"int8",
1333+
None, # No config, which is correct
1334+
None, # No exception expected
1335+
None,
1336+
"non_gptq_runs_without_error",
1337+
),
1338+
]
1339+
1340+
13091341
@pytest.mark.requires_trainable_backend
13101342
class TestModelQuantization:
13111343
def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
@@ -1359,3 +1391,28 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
13591391
def test_quantize_gptq_combinations(self, dataset, config):
13601392
"""Runs GPTQ tests across different datasets and config variations."""
13611393
self._run_gptq_test_on_dataset(dataset, **config)
1394+
1395+
@pytest.mark.parametrize(
1396+
"mode, config, expected_exception, match_message, test_id",
1397+
quantize_test_cases,
1398+
ids=[case[-1] for case in quantize_test_cases],
1399+
)
1400+
def test_quantize_scenarios(
1401+
self, mode, config, expected_exception, match_message, test_id
1402+
):
1403+
"""
1404+
Tests various scenarios for the model.quantize() method, including
1405+
error handling and valid calls.
1406+
"""
1407+
model = _get_simple_model()
1408+
1409+
if expected_exception:
1410+
# Test for cases where an error is expected
1411+
with pytest.raises(expected_exception, match=match_message):
1412+
model.quantize(mode, config=config)
1413+
else:
1414+
# Test for valid cases where no error should occur
1415+
try:
1416+
model.quantize(mode, config=config)
1417+
except (ValueError, TypeError) as e:
1418+
pytest.fail(f"Test case '{test_id}' failed unexpectedly: {e}")

0 commit comments

Comments
 (0)