Skip to content

Commit 085923e

Browse files
Reworked based on the review comments
1 parent 4ac758e commit 085923e

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

keras/src/models/model_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,8 +1245,9 @@ def test_export_error(self):
12451245
def dummy_dataset_generator(nsamples, seqlen, vocab_size=1000):
12461246
"""A generator that yields random numpy arrays for fast,
12471247
self-contained tests."""
1248+
rng = np.random.default_rng(seed=42)
12481249
for _ in range(nsamples):
1249-
yield np.random.randint(0, vocab_size, size=(1, seqlen))
1250+
yield rng.integers(low=0, high=vocab_size, size=(1, seqlen))
12501251

12511252

12521253
# Helper function to build a simple transformer model that uses standard
@@ -1327,7 +1328,7 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
13271328
target_layer,
13281329
"Test setup failed: No Dense layer found in 'ffn' block.",
13291330
)
1330-
original_weights = np.copy(target_layer.kernel.numpy())
1331+
original_weights = np.copy(target_layer.kernel)
13311332

13321333
# Configure and run quantization
13331334
final_config = {**base_config, **config_kwargs}
@@ -1336,7 +1337,7 @@ def _run_gptq_test_on_dataset(self, dataset, **config_kwargs):
13361337
model.quantize("gptq", quant_config=gptq_config)
13371338

13381339
# Assertions and verification
1339-
quantized_weights = target_layer.kernel.numpy()
1340+
quantized_weights = target_layer.kernel
13401341

13411342
self.assertNotAllClose(
13421343
original_weights,
@@ -1353,7 +1354,7 @@ def test_quantize_gptq_on_different_datasets(self):
13531354
"""Tests GPTQ with various dataset types (string list, generator)."""
13541355

13551356
# Define the datasets to be tested
1356-
long_text = """auto-gptq is an easy-to-use model quantization library
1357+
long_text = """gptq is an easy-to-use model quantization library
13571358
with user-friendly apis, based on GPTQ algorithm. The goal is to
13581359
quantize pre-trained models to 4-bit or even 3-bit precision with
13591360
minimal performance degradation.
@@ -1374,8 +1375,6 @@ def test_quantize_gptq_on_different_datasets(self):
13741375

13751376
# Loop through the datasets and run each as a sub-test
13761377
for dataset_name, dataset in datasets_to_test.items():
1377-
# 'with self.subTest(...)' ensures that failures are reported
1378-
# for each specific dataset without stopping the whole test.
13791378
with self.subTest(dataset_type=dataset_name):
13801379
self._run_gptq_test_on_dataset(dataset)
13811380

0 commit comments

Comments (0)