Skip to content

Commit 03a2ebf

Browse files
committed
test: Test case for preprpreprocessor test file
1 parent 92b0d70 commit 03a2ebf

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
from keras import ops
3+
from keras_hub.src.tests.test_case import TestCase
4+
from keras_hub.src.models.modernbert.modernbert_tokenizer import ModernBertTokenizer
5+
from keras_hub.src.models.modernbert.modernbert_preprocessor import ModernBertPreprocessor
6+
7+
class ModernBertPreprocessorTest(TestCase):
8+
"""
9+
Test suite for ModernBertPreprocessor.
10+
"""
11+
def setUp(self):
12+
self.vocab = {
13+
"<|endoftext|>": 0,
14+
"<|padding|>": 1,
15+
"a": 2, "b": 3, "c": 4, "ab": 5,
16+
}
17+
self.merges = ["a b"]
18+
self.tokenizer = ModernBertTokenizer(
19+
vocabulary=self.vocab,
20+
merges=self.merges
21+
)
22+
self.preprocessor = ModernBertPreprocessor(
23+
tokenizer=self.tokenizer,
24+
sequence_length=4
25+
)
26+
27+
def test_preprocess_dict(self):
28+
"""Check that output is a dict with correct keys and shapes."""
29+
input_data = ["ab"]
30+
output = self.preprocessor(input_data)
31+
self.assertAllEqual(ops.shape(output["token_ids"]), [1, 4])
32+
self.assertAllEqual(ops.shape(output["padding_mask"]), [1, 4])
33+
34+
def test_padding_logic(self):
35+
"""Verify that sequence padding and ID mapping work correctly."""
36+
input_data = ["a"]
37+
output = self.preprocessor(input_data)
38+
token_ids = ops.convert_to_numpy(output["token_ids"])
39+
# Expected: [2, 1, 1, 1] (ID 2 is 'a', ID 1 is pad)
40+
self.assertEqual(token_ids[0, 0], 2)
41+
self.assertEqual(token_ids[0, 1], 1)
42+
43+
def test_serialization(self):
44+
"""
45+
Ensure preprocessor can be reconstructed from config.
46+
"""
47+
new_preprocessor = ModernBertPreprocessor.from_config(
48+
self.preprocessor.get_config()
49+
)
50+
self.assertEqual(new_preprocessor.sequence_length, 4)
51+
self.assertEqual(
52+
new_preprocessor.tokenizer.pad_token_id,
53+
self.tokenizer.pad_token_id
54+
)

0 commit comments

Comments
 (0)