
Commit 2605059

Tokenization unit tests (#90)
* tokenizer unit tests
* add tokenizer pool tests
* lint
* test internals
* lint
* lint
* lint
* tidy

Signed-off-by: Sage Ahrac <[email protected]>
1 parent 107e3a9 commit 2605059
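
Note: every new test that loads the real google-bert/bert-base-uncased tokenizer guards itself with testing.Short(), so a go test -short run skips the network-dependent cases.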

File tree

3 files changed: 248 additions & 0 deletions


go.mod

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ require (
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/prometheus/common v0.62.0 // indirect
 	github.com/prometheus/procfs v0.15.1 // indirect
+	github.com/stretchr/objx v0.5.2 // indirect
 	github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
 	github.com/x448/float16 v0.8.4 // indirect
 	github.com/yuin/gopher-lua v1.1.1 // indirect
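
The single go.mod change is the new github.com/stretchr/objx entry, an indirect dependency pulled in by testify's mock package, which the new tests below rely on.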

pkg/tokenization/pool_test.go

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
/*
Copyright 2025 The llm-d Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

//nolint:testpackage // need to test internal types
package tokenization

import (
    "context"
    "testing"
    "time"

    "github.com/daulet/tokenizers"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/mock"
    "github.com/stretchr/testify/require"
)

// MockTokenizer implements the Tokenizer interface for testing.
type MockTokenizer struct {
    mock.Mock
}

func (m *MockTokenizer) Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error) {
    args := m.Called(input, modelName)
    return args.Get(0).([]uint32), args.Get(1).([]tokenizers.Offset), args.Error(2) //nolint:errcheck // return mocked values
}

// MockIndexer implements the prefixstore.Indexer interface for testing.
type MockIndexer struct {
    mock.Mock
}

func (m *MockIndexer) AddTokenization(modelName, prompt string, tokens []uint32, offsets []tokenizers.Offset) error {
    args := m.Called(modelName, prompt, tokens, offsets)
    return args.Error(0)
}

func (m *MockIndexer) FindLongestContainedTokens(prompt, modelName string) []uint32 {
    args := m.Called(prompt, modelName)
    return args.Get(0).([]uint32) //nolint:errcheck // unused mock
}

func TestPool_ProcessTask(t *testing.T) {
    mockIndexer := &MockIndexer{}
    mockTokenizer := &MockTokenizer{}

    pool := &Pool{
        workers:   1,
        indexer:   mockIndexer,
        tokenizer: mockTokenizer,
    }

    task := Task{
        Prompt:    "hello world",
        ModelName: testModelName,
    }

    // Set up specific mock return values.
    expectedTokens := []uint32{12345, 67890, 11111}
    expectedOffsets := []tokenizers.Offset{{0, 5}, {6, 11}}

    mockTokenizer.On("Encode", task.Prompt, task.ModelName).Return(expectedTokens, expectedOffsets, nil)

    // Verify that the indexer receives exactly the tokens and offsets the tokenizer returned.
    mockIndexer.On("AddTokenization", task.ModelName, task.Prompt, expectedTokens, expectedOffsets).Return(nil)

    // Execute
    err := pool.processTask(task)

    // Assert
    assert.NoError(t, err)
    mockTokenizer.AssertExpectations(t)
    mockIndexer.AssertExpectations(t)
}

func TestPool_RunIntegration(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping tokenizer integration test in short mode")
    }

    mockIndexer := &MockIndexer{}

    prompts := []string{"hello world", "this is a test", "unicode test: 世界"}

    // Set up mock expectations for each prompt.
    for _, prompt := range prompts {
        mockIndexer.On("AddTokenization", testModelName, prompt,
            mock.Anything, mock.Anything).Return(nil).Once()
    }

    config := &Config{
        WorkersCount: 2,
        HFTokenizerConfig: &HFTokenizerConfig{
            TokenizersCacheDir: t.TempDir(),
        },
    }

    pool, err := NewTokenizationPool(config, mockIndexer)
    require.NoError(t, err)

    // Create a context for the pool.
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    for _, prompt := range prompts {
        pool.AddTask(prompt, testModelName)
    }

    // Run the pool.
    done := make(chan struct{})
    go func() {
        defer close(done)
        pool.Run(ctx)
    }()

    time.Sleep(2 * time.Second)
    cancel()
    <-done

    mockIndexer.AssertExpectations(t)
}
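
For orientation, the two mocks above imply interfaces along these lines. This is a sketch inferred from the mocks' method sets, not the package's actual declarations:

// Inferred sketch only: reconstructed from MockTokenizer and MockIndexer above.
// The real Tokenizer lives in this package and the real Indexer in a
// prefixstore package (per the mock's doc comment); declarations may differ.
type Tokenizer interface {
    Encode(input, modelName string) ([]uint32, []tokenizers.Offset, error)
}

type Indexer interface {
    AddTokenization(modelName, prompt string, tokens []uint32, offsets []tokenizers.Offset) error
    FindLongestContainedTokens(prompt, modelName string) []uint32
}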

pkg/tokenization/tokenizer_test.go

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
/*
Copyright 2025 The llm-d Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

//nolint:testpackage // need to test internal types
package tokenization

import (
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

// Tests that load this real model are skipped in fast (-short) unit-test runs.
const testModelName = "google-bert/bert-base-uncased"

func TestCachedHFTokenizer_Encode(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping tokenizer integration test in short mode")
    }

    config := &HFTokenizerConfig{
        TokenizersCacheDir: t.TempDir(),
    }
    tokenizer, err := NewCachedHFTokenizer(config)
    require.NoError(t, err)
    require.NotNil(t, tokenizer)

    tests := []struct {
        name      string
        input     string
        modelName string
    }{
        {
            name:      "simple text",
            input:     "hello world",
            modelName: testModelName,
        },
        {
            name:      "empty string",
            input:     "",
            modelName: testModelName,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tokenIds, offsets, err := tokenizer.Encode(tt.input, tt.modelName)

            assert.NoError(t, err)
            assert.GreaterOrEqual(t, len(tokenIds), 0)
            assert.Equal(t, len(tokenIds), len(offsets))
        })
    }
}

func TestCachedHFTokenizer_CacheTokenizer(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping tokenizer integration test in short mode")
    }

    tokenizer, err := NewCachedHFTokenizer(&HFTokenizerConfig{
        TokenizersCacheDir: t.TempDir(),
    })
    require.NoError(t, err)
    require.NotNil(t, tokenizer)

    // Test that the same model is cached.
    input := "test input"

    // First call loads the tokenizer.
    tokenIds1, offsets1, err1 := tokenizer.Encode(input, testModelName)
    require.NoError(t, err1)

    // Second call should use the cached tokenizer.
    tokenIds2, offsets2, err2 := tokenizer.Encode(input, testModelName)
    require.NoError(t, err2)

    // Results should be identical.
    assert.Equal(t, tokenIds1, tokenIds2)
    assert.Equal(t, offsets1, offsets2)
}

func TestCachedHFTokenizer_InvalidModel(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping tokenizer integration test in short mode")
    }

    tokenizer, err := NewCachedHFTokenizer(&HFTokenizerConfig{
        TokenizersCacheDir: t.TempDir(),
    })
    require.NoError(t, err)
    require.NotNil(t, tokenizer)

    // Test with a non-existent model.
    tokenIds, offsets, err := tokenizer.Encode("test", "non-existent/model")
    assert.Error(t, err)
    assert.Nil(t, tokenIds)
    assert.Nil(t, offsets)
}
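
Taken together, these tests exercise a small API surface. Below is a minimal usage sketch assuming only the constructor and Encode signature shown above; the package's import path is not part of this diff, so the snippet is written as if in the same package, and the cache directory is an arbitrary placeholder:

// Usage sketch based solely on the calls made in the tests above.
func exampleEncode() ([]uint32, error) {
    tok, err := NewCachedHFTokenizer(&HFTokenizerConfig{
        TokenizersCacheDir: "/tmp/tokenizers-cache", // placeholder: any writable directory
    })
    if err != nil {
        return nil, err
    }
    // Per TestCachedHFTokenizer_CacheTokenizer, the first Encode for a model
    // loads its tokenizer and later calls for the same model reuse it.
    ids, offsets, err := tok.Encode("hello world", "google-bert/bert-base-uncased")
    if err != nil {
        return nil, err
    }
    _ = offsets // offset ranges into the input, one per token ID
    return ids, nil
}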
