Skip to content

Commit abad2d8

Browse files
ilopezluna and Copilot authored
fix(safetensors): include tokenizer.model (#244)
* fix(safetensors): include tokenizer.model in config file collection and add tests

* Update pkg/distribution/packaging/safetensors_test.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9be9145 commit abad2d8

File tree

2 files changed

+299
-2
lines changed

2 files changed

+299
-2
lines changed

pkg/distribution/packaging/safetensors.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig
3636
safetensorsPaths = append(safetensorsPaths, fullPath)
3737
}
3838

39-
// Collect config files: *.json, merges.txt
40-
if strings.HasSuffix(strings.ToLower(name), ".json") || strings.EqualFold(name, "merges.txt") {
39+
// Collect config files: *.json, merges.txt and tokenizer.model
40+
if strings.HasSuffix(strings.ToLower(name), ".json") || strings.EqualFold(name, "merges.txt") || strings.EqualFold(name, "tokenizer.model") {
4141
configFiles = append(configFiles, fullPath)
4242
}
4343
}
pkg/distribution/packaging/safetensors_test.go

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
package packaging
2+
3+
import (
4+
"archive/tar"
5+
"io"
6+
"os"
7+
"path/filepath"
8+
"sort"
9+
"strings"
10+
"testing"
11+
)
12+
13+
func TestPackageFromDirectory_WithTokenizerModel(t *testing.T) {
14+
// Create temporary directory
15+
tempDir := t.TempDir()
16+
17+
// Create test files
18+
files := map[string]string{
19+
"model.safetensors": "safetensors content",
20+
"config.json": `{"model_type": "test"}`,
21+
"tokenizer.model": "tokenizer model binary content",
22+
"tokenizer_config.json": `{"tokenizer_class": "TestTokenizer"}`,
23+
"not.included": `not included content`,
24+
}
25+
26+
for name, content := range files {
27+
path := filepath.Join(tempDir, name)
28+
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
29+
t.Fatalf("Failed to create test file %s: %v", name, err)
30+
}
31+
}
32+
33+
// Call PackageFromDirectory
34+
safetensorsPaths, tempConfigArchive, err := PackageFromDirectory(tempDir)
35+
if err != nil {
36+
t.Fatalf("PackageFromDirectory failed: %v", err)
37+
}
38+
39+
// Clean up temp archive
40+
if tempConfigArchive != "" {
41+
defer os.Remove(tempConfigArchive)
42+
}
43+
44+
// Verify safetensors files were found
45+
if len(safetensorsPaths) != 1 {
46+
t.Errorf("Expected 1 safetensors file, got %d", len(safetensorsPaths))
47+
}
48+
49+
// Verify config archive was created
50+
if tempConfigArchive == "" {
51+
t.Fatal("Expected config archive to be created")
52+
}
53+
54+
// Verify tokenizer.model is in the archive
55+
archiveFiles, err := readTarArchive(tempConfigArchive)
56+
if err != nil {
57+
t.Fatalf("Failed to read tar archive: %v", err)
58+
}
59+
60+
expectedFiles := []string{"config.json", "tokenizer.model", "tokenizer_config.json"}
61+
sort.Strings(expectedFiles)
62+
sort.Strings(archiveFiles)
63+
64+
if len(archiveFiles) != len(expectedFiles) {
65+
t.Errorf("Expected %d files in archive, got %d", len(expectedFiles), len(archiveFiles))
66+
}
67+
68+
for i, expected := range expectedFiles {
69+
if i >= len(archiveFiles) || archiveFiles[i] != expected {
70+
t.Errorf("Expected file %s in archive, got %v", expected, archiveFiles)
71+
}
72+
}
73+
}
74+
75+
func TestPackageFromDirectory_BasicFunctionality(t *testing.T) {
76+
// Create temporary directory
77+
tempDir := t.TempDir()
78+
79+
// Create test files
80+
files := map[string]string{
81+
"model-00001-of-00002.safetensors": "safetensors content 1",
82+
"model-00002-of-00002.safetensors": "safetensors content 2",
83+
"config.json": `{"model_type": "test"}`,
84+
"merges.txt": "merge1 merge2",
85+
"tokenizer.model": "tokenizer content",
86+
"special_tokens_map.json": `{"unk_token": "<unk>"}`,
87+
}
88+
89+
for name, content := range files {
90+
path := filepath.Join(tempDir, name)
91+
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
92+
t.Fatalf("Failed to create test file %s: %v", name, err)
93+
}
94+
}
95+
96+
// Call PackageFromDirectory
97+
safetensorsPaths, tempConfigArchive, err := PackageFromDirectory(tempDir)
98+
if err != nil {
99+
t.Fatalf("PackageFromDirectory failed: %v", err)
100+
}
101+
102+
// Clean up temp archive
103+
if tempConfigArchive != "" {
104+
defer os.Remove(tempConfigArchive)
105+
}
106+
107+
// Verify safetensors files
108+
if len(safetensorsPaths) != 2 {
109+
t.Errorf("Expected 2 safetensors files, got %d", len(safetensorsPaths))
110+
}
111+
112+
// Verify files are sorted
113+
for i := 0; i < len(safetensorsPaths)-1; i++ {
114+
if safetensorsPaths[i] > safetensorsPaths[i+1] {
115+
t.Error("Safetensors paths are not sorted")
116+
}
117+
}
118+
119+
// Verify config archive was created
120+
if tempConfigArchive == "" {
121+
t.Fatal("Expected config archive to be created")
122+
}
123+
124+
// Verify archive contents
125+
archiveFiles, err := readTarArchive(tempConfigArchive)
126+
if err != nil {
127+
t.Fatalf("Failed to read tar archive: %v", err)
128+
}
129+
130+
expectedConfigFiles := []string{
131+
"config.json",
132+
"merges.txt",
133+
"tokenizer.model",
134+
"special_tokens_map.json",
135+
}
136+
sort.Strings(expectedConfigFiles)
137+
sort.Strings(archiveFiles)
138+
139+
if len(archiveFiles) != len(expectedConfigFiles) {
140+
t.Errorf("Expected %d config files in archive, got %d", len(expectedConfigFiles), len(archiveFiles))
141+
}
142+
143+
for i, expected := range expectedConfigFiles {
144+
if i >= len(archiveFiles) || archiveFiles[i] != expected {
145+
t.Errorf("Expected file %s in archive at position %d, got %v", expected, i, archiveFiles)
146+
}
147+
}
148+
}
149+
150+
func TestPackageFromDirectory_NoSafetensorsFiles(t *testing.T) {
151+
// Create temporary directory
152+
tempDir := t.TempDir()
153+
154+
// Create only config files (no safetensors)
155+
files := map[string]string{
156+
"config.json": `{"model_type": "test"}`,
157+
"tokenizer.model": "tokenizer content",
158+
}
159+
160+
for name, content := range files {
161+
path := filepath.Join(tempDir, name)
162+
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
163+
t.Fatalf("Failed to create test file %s: %v", name, err)
164+
}
165+
}
166+
167+
// Call PackageFromDirectory
168+
_, _, err := PackageFromDirectory(tempDir)
169+
if err == nil {
170+
t.Fatal("Expected error when no safetensors files found, got nil")
171+
}
172+
173+
expectedError := "no safetensors files found"
174+
if !strings.Contains(err.Error(), expectedError) {
175+
t.Errorf("Expected error containing %q, got: %v", expectedError, err)
176+
}
177+
}
178+
179+
func TestPackageFromDirectory_OnlySafetensorsFiles(t *testing.T) {
180+
// Create temporary directory
181+
tempDir := t.TempDir()
182+
183+
// Create only safetensors files (no config files)
184+
files := map[string]string{
185+
"model.safetensors": "safetensors content",
186+
}
187+
188+
for name, content := range files {
189+
path := filepath.Join(tempDir, name)
190+
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
191+
t.Fatalf("Failed to create test file %s: %v", name, err)
192+
}
193+
}
194+
195+
// Call PackageFromDirectory
196+
safetensorsPaths, tempConfigArchive, err := PackageFromDirectory(tempDir)
197+
if err != nil {
198+
t.Fatalf("PackageFromDirectory failed: %v", err)
199+
}
200+
201+
// Verify safetensors files were found
202+
if len(safetensorsPaths) != 1 {
203+
t.Errorf("Expected 1 safetensors file, got %d", len(safetensorsPaths))
204+
}
205+
206+
// Verify no config archive was created
207+
if tempConfigArchive != "" {
208+
defer os.Remove(tempConfigArchive)
209+
t.Error("Expected no config archive to be created when no config files exist")
210+
}
211+
}
212+
213+
func TestPackageFromDirectory_SkipsSubdirectories(t *testing.T) {
214+
// Create temporary directory
215+
tempDir := t.TempDir()
216+
217+
// Create test files in root
218+
rootFiles := map[string]string{
219+
"model.safetensors": "safetensors content",
220+
"config.json": `{"model_type": "test"}`,
221+
}
222+
223+
for name, content := range rootFiles {
224+
path := filepath.Join(tempDir, name)
225+
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
226+
t.Fatalf("Failed to create test file %s: %v", name, err)
227+
}
228+
}
229+
230+
// Create subdirectory with files that should be ignored
231+
subDir := filepath.Join(tempDir, "subdir")
232+
if err := os.Mkdir(subDir, 0755); err != nil {
233+
t.Fatalf("Failed to create subdirectory: %v", err)
234+
}
235+
236+
subFiles := map[string]string{
237+
"ignored.safetensors": "should be ignored",
238+
"ignored.json": `{"ignored": true}`,
239+
}
240+
241+
for name, content := range subFiles {
242+
path := filepath.Join(subDir, name)
243+
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
244+
t.Fatalf("Failed to create test file in subdir %s: %v", name, err)
245+
}
246+
}
247+
248+
// Call PackageFromDirectory
249+
safetensorsPaths, tempConfigArchive, err := PackageFromDirectory(tempDir)
250+
if err != nil {
251+
t.Fatalf("PackageFromDirectory failed: %v", err)
252+
}
253+
254+
// Clean up temp archive
255+
if tempConfigArchive != "" {
256+
defer os.Remove(tempConfigArchive)
257+
}
258+
259+
// Verify only root-level files were processed
260+
if len(safetensorsPaths) != 1 {
261+
t.Errorf("Expected 1 safetensors file from root directory, got %d", len(safetensorsPaths))
262+
}
263+
264+
archiveFiles, err := readTarArchive(tempConfigArchive)
265+
if err != nil {
266+
t.Fatalf("Failed to read tar archive: %v", err)
267+
}
268+
269+
if len(archiveFiles) != 1 || archiveFiles[0] != "config.json" {
270+
t.Errorf("Expected only config.json from root directory, got %v", archiveFiles)
271+
}
272+
}
273+
274+
// readTarArchive opens the tar file at archivePath and returns the Name field
// of every header, in the order the entries appear. It returns a nil slice for
// an archive with no entries, and a nil slice plus the underlying error when
// the file cannot be opened or a header cannot be read.
func readTarArchive(archivePath string) ([]string, error) {
	f, err := os.Open(archivePath)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var names []string
	reader := tar.NewReader(f)
	for {
		hdr, nextErr := reader.Next()
		switch {
		case nextErr == io.EOF:
			// End of archive: everything collected so far is the result.
			return names, nil
		case nextErr != nil:
			return nil, nextErr
		}
		names = append(names, hdr.Name)
	}
}

0 commit comments

Comments
 (0)