Skip to content

Commit 4f6797f

Browse files
authored
Fix bundle (#235)
* fix(bundle): move model files to a dedicated subdirectory and update paths
* fix(parse): validate model subdirectory existence for new bundle format
* fix(parse): add validation for model weight formats in bundle parsing
* fix(bundle): create model subdirectory upfront for unpack operations
1 parent 97a5f3e commit 4f6797f

File tree

5 files changed

+281
-30
lines changed

5 files changed

+281
-30
lines changed

pkg/distribution/distribution/bundle_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,31 +93,31 @@ func TestBundle(t *testing.T) {
9393
ref: singleGGUFID,
9494
description: "single file GGUF by ID",
9595
expectedFiles: map[string]string{
96-
"model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
96+
"model/model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
9797
},
9898
},
9999
{
100100
ref: shardedGGUFID,
101101
description: "sharded GGUF by ID",
102102
expectedFiles: map[string]string{
103-
"model-00001-of-00002.gguf": filepath.Join("..", "assets", "dummy-00001-of-00002.gguf"),
104-
"model-00002-of-00002.gguf": filepath.Join("..", "assets", "dummy-00002-of-00002.gguf"),
103+
"model/model-00001-of-00002.gguf": filepath.Join("..", "assets", "dummy-00001-of-00002.gguf"),
104+
"model/model-00002-of-00002.gguf": filepath.Join("..", "assets", "dummy-00002-of-00002.gguf"),
105105
},
106106
},
107107
{
108108
ref: mmprojMdlID,
109-
description: "sharded GGUF by ID",
109+
description: "model with mmproj file",
110110
expectedFiles: map[string]string{
111-
"model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
112-
"model.mmproj": filepath.Join("..", "assets", "dummy.mmproj"),
111+
"model/model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
112+
"model/model.mmproj": filepath.Join("..", "assets", "dummy.mmproj"),
113113
},
114114
},
115115
{
116116
ref: templateMdlID,
117117
description: "model with template file",
118118
expectedFiles: map[string]string{
119-
"model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
120-
"template.jinja": filepath.Join("..", "assets", "template.jinja"),
119+
"model/model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
120+
"model/template.jinja": filepath.Join("..", "assets", "template.jinja"),
121121
},
122122
},
123123
}

pkg/distribution/internal/bundle/bundle.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ import (
66
"github.com/docker/model-runner/pkg/distribution/types"
77
)
88

9+
const (
10+
// ModelSubdir is the subdirectory within a bundle where model files are stored.
11+
ModelSubdir = "model"
12+
)
13+
914
// Bundle represents a runtime bundle containing a model and runtime config
1015
type Bundle struct {
1116
dir string
@@ -27,31 +32,31 @@ func (b *Bundle) GGUFPath() string {
2732
if b.ggufFile == "" {
2833
return ""
2934
}
30-
return filepath.Join(b.dir, b.ggufFile)
35+
return filepath.Join(b.dir, ModelSubdir, b.ggufFile)
3136
}
3237

3338
// MMPROJPath returns the path to a multi-modal projector file or "" if none is present.
3439
func (b *Bundle) MMPROJPath() string {
3540
if b.mmprojPath == "" {
3641
return ""
3742
}
38-
return filepath.Join(b.dir, b.mmprojPath)
43+
return filepath.Join(b.dir, ModelSubdir, b.mmprojPath)
3944
}
4045

4146
// ChatTemplatePath returns the path to a Jinja chat template file or "" if none is present.
4247
func (b *Bundle) ChatTemplatePath() string {
4348
if b.chatTemplatePath == "" {
4449
return ""
4550
}
46-
return filepath.Join(b.dir, b.chatTemplatePath)
51+
return filepath.Join(b.dir, ModelSubdir, b.chatTemplatePath)
4752
}
4853

4954
// SafetensorsPath returns the path to the model safetensors file. If the model is sharded, this is the path to the first shard.
5055
func (b *Bundle) SafetensorsPath() string {
5156
if b.safetensorsFile == "" {
5257
return ""
5358
}
54-
return filepath.Join(b.dir, b.safetensorsFile)
59+
return filepath.Join(b.dir, ModelSubdir, b.safetensorsFile)
5560
}
5661

5762
// RuntimeConfig returns config that should be respected by the backend at runtime.

pkg/distribution/internal/bundle/parse.go

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,38 @@ func Parse(rootDir string) (*Bundle, error) {
1414
if fi, err := os.Stat(rootDir); err != nil || !fi.IsDir() {
1515
return nil, fmt.Errorf("inspect bundle root dir: %w", err)
1616
}
17-
ggufPath, err := findGGUFFile(rootDir)
17+
18+
// Check if model subdirectory exists - required for new bundle format
19+
// If it doesn't exist, this is an old bundle format that needs to be recreated
20+
modelDir := filepath.Join(rootDir, ModelSubdir)
21+
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
22+
return nil, fmt.Errorf("bundle uses old format (missing %s subdirectory), will be recreated", ModelSubdir)
23+
}
24+
25+
ggufPath, err := findGGUFFile(modelDir)
1826
if err != nil {
1927
return nil, err
2028
}
21-
mmprojPath, err := findMultiModalProjectorFile(rootDir)
29+
safetensorsPath, err := findSafetensorsFile(modelDir)
30+
if err != nil {
31+
return nil, err
32+
}
33+
34+
// Ensure at least one model weight format is present
35+
if ggufPath == "" && safetensorsPath == "" {
36+
return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)")
37+
}
38+
39+
mmprojPath, err := findMultiModalProjectorFile(modelDir)
2240
if err != nil {
2341
return nil, err
2442
}
25-
templatePath, err := findChatTemplateFile(rootDir)
43+
templatePath, err := findChatTemplateFile(modelDir)
2644
if err != nil {
2745
return nil, err
2846
}
47+
48+
// Runtime config stays at bundle root
2949
cfg, err := parseRuntimeConfig(rootDir)
3050
if err != nil {
3151
return nil, err
@@ -34,6 +54,7 @@ func Parse(rootDir string) (*Bundle, error) {
3454
dir: rootDir,
3555
mmprojPath: mmprojPath,
3656
ggufFile: ggufPath,
57+
safetensorsFile: safetensorsPath,
3758
runtimeConfig: cfg,
3859
chatTemplatePath: templatePath,
3960
}, nil
@@ -52,19 +73,32 @@ func parseRuntimeConfig(rootDir string) (types.Config, error) {
5273
return cfg, nil
5374
}
5475

55-
func findGGUFFile(rootDir string) (string, error) {
56-
ggufs, err := filepath.Glob(filepath.Join(rootDir, "[^.]*.gguf"))
76+
func findGGUFFile(modelDir string) (string, error) {
77+
ggufs, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.gguf"))
5778
if err != nil {
5879
return "", fmt.Errorf("find gguf files: %w", err)
5980
}
6081
if len(ggufs) == 0 {
61-
return "", fmt.Errorf("no GGUF files found in bundle directory")
82+
// GGUF files are optional - safetensors models won't have them
83+
return "", nil
6284
}
6385
return filepath.Base(ggufs[0]), nil
6486
}
6587

66-
func findMultiModalProjectorFile(rootDir string) (string, error) {
67-
mmprojPaths, err := filepath.Glob(filepath.Join(rootDir, "[^.]*.mmproj"))
88+
func findSafetensorsFile(modelDir string) (string, error) {
89+
safetensors, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.safetensors"))
90+
if err != nil {
91+
return "", fmt.Errorf("find safetensors files: %w", err)
92+
}
93+
if len(safetensors) == 0 {
94+
// Safetensors files are optional - GGUF models won't have them
95+
return "", nil
96+
}
97+
return filepath.Base(safetensors[0]), nil
98+
}
99+
100+
func findMultiModalProjectorFile(modelDir string) (string, error) {
101+
mmprojPaths, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.mmproj"))
68102
if err != nil {
69103
return "", err
70104
}
@@ -77,8 +111,8 @@ func findMultiModalProjectorFile(rootDir string) (string, error) {
77111
return filepath.Base(mmprojPaths[0]), nil
78112
}
79113

80-
func findChatTemplateFile(rootDir string) (string, error) {
81-
templatePaths, err := filepath.Glob(filepath.Join(rootDir, "[^.]*.jinja"))
114+
func findChatTemplateFile(modelDir string) (string, error) {
115+
templatePaths, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.jinja"))
82116
if err != nil {
83117
return "", err
84118
}
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package bundle
2+
3+
import (
4+
"encoding/json"
5+
"os"
6+
"path/filepath"
7+
"strings"
8+
"testing"
9+
10+
"github.com/docker/model-runner/pkg/distribution/types"
11+
)
12+
13+
func TestParse_NoModelWeights(t *testing.T) {
14+
// Create a temporary directory for the test bundle
15+
tempDir := t.TempDir()
16+
17+
// Create model subdirectory
18+
modelDir := filepath.Join(tempDir, ModelSubdir)
19+
if err := os.MkdirAll(modelDir, 0755); err != nil {
20+
t.Fatalf("Failed to create model directory: %v", err)
21+
}
22+
23+
// Create a valid config.json at bundle root
24+
cfg := types.Config{
25+
Format: types.FormatGGUF,
26+
}
27+
configPath := filepath.Join(tempDir, "config.json")
28+
f, err := os.Create(configPath)
29+
if err != nil {
30+
t.Fatalf("Failed to create config.json: %v", err)
31+
}
32+
if err := json.NewEncoder(f).Encode(cfg); err != nil {
33+
f.Close()
34+
t.Fatalf("Failed to encode config: %v", err)
35+
}
36+
f.Close()
37+
38+
// Try to parse the bundle - should fail because no model weights are present
39+
_, err = Parse(tempDir)
40+
if err == nil {
41+
t.Fatal("Expected error when parsing bundle without model weights, got nil")
42+
}
43+
44+
expectedErrMsg := "no supported model weights found (neither GGUF nor safetensors)"
45+
if !strings.Contains(err.Error(), expectedErrMsg) {
46+
t.Errorf("Expected error message to contain %q, got: %v", expectedErrMsg, err)
47+
}
48+
}
49+
50+
func TestParse_WithGGUF(t *testing.T) {
51+
// Create a temporary directory for the test bundle
52+
tempDir := t.TempDir()
53+
54+
// Create model subdirectory
55+
modelDir := filepath.Join(tempDir, ModelSubdir)
56+
if err := os.MkdirAll(modelDir, 0755); err != nil {
57+
t.Fatalf("Failed to create model directory: %v", err)
58+
}
59+
60+
// Create a dummy GGUF file
61+
ggufPath := filepath.Join(modelDir, "model.gguf")
62+
if err := os.WriteFile(ggufPath, []byte("dummy gguf content"), 0644); err != nil {
63+
t.Fatalf("Failed to create GGUF file: %v", err)
64+
}
65+
66+
// Create a valid config.json at bundle root
67+
cfg := types.Config{
68+
Format: types.FormatGGUF,
69+
}
70+
configPath := filepath.Join(tempDir, "config.json")
71+
f, err := os.Create(configPath)
72+
if err != nil {
73+
t.Fatalf("Failed to create config.json: %v", err)
74+
}
75+
if err := json.NewEncoder(f).Encode(cfg); err != nil {
76+
f.Close()
77+
t.Fatalf("Failed to encode config: %v", err)
78+
}
79+
f.Close()
80+
81+
// Parse the bundle - should succeed
82+
bundle, err := Parse(tempDir)
83+
if err != nil {
84+
t.Fatalf("Expected successful parse with GGUF file, got error: %v", err)
85+
}
86+
87+
if bundle.ggufFile != "model.gguf" {
88+
t.Errorf("Expected ggufFile to be 'model.gguf', got: %s", bundle.ggufFile)
89+
}
90+
91+
if bundle.safetensorsFile != "" {
92+
t.Errorf("Expected safetensorsFile to be empty, got: %s", bundle.safetensorsFile)
93+
}
94+
}
95+
96+
func TestParse_WithSafetensors(t *testing.T) {
97+
// Create a temporary directory for the test bundle
98+
tempDir := t.TempDir()
99+
100+
// Create model subdirectory
101+
modelDir := filepath.Join(tempDir, ModelSubdir)
102+
if err := os.MkdirAll(modelDir, 0755); err != nil {
103+
t.Fatalf("Failed to create model directory: %v", err)
104+
}
105+
106+
// Create a dummy safetensors file
107+
safetensorsPath := filepath.Join(modelDir, "model.safetensors")
108+
if err := os.WriteFile(safetensorsPath, []byte("dummy safetensors content"), 0644); err != nil {
109+
t.Fatalf("Failed to create safetensors file: %v", err)
110+
}
111+
112+
// Create a valid config.json at bundle root
113+
cfg := types.Config{
114+
Format: types.FormatSafetensors,
115+
}
116+
configPath := filepath.Join(tempDir, "config.json")
117+
f, err := os.Create(configPath)
118+
if err != nil {
119+
t.Fatalf("Failed to create config.json: %v", err)
120+
}
121+
if err := json.NewEncoder(f).Encode(cfg); err != nil {
122+
f.Close()
123+
t.Fatalf("Failed to encode config: %v", err)
124+
}
125+
f.Close()
126+
127+
// Parse the bundle - should succeed
128+
bundle, err := Parse(tempDir)
129+
if err != nil {
130+
t.Fatalf("Expected successful parse with safetensors file, got error: %v", err)
131+
}
132+
133+
if bundle.safetensorsFile != "model.safetensors" {
134+
t.Errorf("Expected safetensorsFile to be 'model.safetensors', got: %s", bundle.safetensorsFile)
135+
}
136+
137+
if bundle.ggufFile != "" {
138+
t.Errorf("Expected ggufFile to be empty, got: %s", bundle.ggufFile)
139+
}
140+
}
141+
142+
func TestParse_WithBothFormats(t *testing.T) {
143+
// Create a temporary directory for the test bundle
144+
tempDir := t.TempDir()
145+
146+
// Create model subdirectory
147+
modelDir := filepath.Join(tempDir, ModelSubdir)
148+
if err := os.MkdirAll(modelDir, 0755); err != nil {
149+
t.Fatalf("Failed to create model directory: %v", err)
150+
}
151+
152+
// Create both GGUF and safetensors files
153+
ggufPath := filepath.Join(modelDir, "model.gguf")
154+
if err := os.WriteFile(ggufPath, []byte("dummy gguf content"), 0644); err != nil {
155+
t.Fatalf("Failed to create GGUF file: %v", err)
156+
}
157+
158+
safetensorsPath := filepath.Join(modelDir, "model.safetensors")
159+
if err := os.WriteFile(safetensorsPath, []byte("dummy safetensors content"), 0644); err != nil {
160+
t.Fatalf("Failed to create safetensors file: %v", err)
161+
}
162+
163+
// Create a valid config.json at bundle root
164+
cfg := types.Config{
165+
Format: types.FormatGGUF,
166+
}
167+
configPath := filepath.Join(tempDir, "config.json")
168+
f, err := os.Create(configPath)
169+
if err != nil {
170+
t.Fatalf("Failed to create config.json: %v", err)
171+
}
172+
if err := json.NewEncoder(f).Encode(cfg); err != nil {
173+
f.Close()
174+
t.Fatalf("Failed to encode config: %v", err)
175+
}
176+
f.Close()
177+
178+
// Parse the bundle - should succeed with both files present
179+
bundle, err := Parse(tempDir)
180+
if err != nil {
181+
t.Fatalf("Expected successful parse with both formats, got error: %v", err)
182+
}
183+
184+
if bundle.ggufFile != "model.gguf" {
185+
t.Errorf("Expected ggufFile to be 'model.gguf', got: %s", bundle.ggufFile)
186+
}
187+
188+
if bundle.safetensorsFile != "model.safetensors" {
189+
t.Errorf("Expected safetensorsFile to be 'model.safetensors', got: %s", bundle.safetensorsFile)
190+
}
191+
}

0 commit comments

Comments
 (0)