Skip to content

Commit a20c7c8

Browse files
ilopezluna authored and doringeman committed
Add support to mmproj (docker#108)
* Add support for Multimodal projector file
* Add tests for Multimodal support
* Instead of modifying Builder for testing purposes, add an implementation of Target to the test code that accomplishes the same thing
1 parent 4274e4c commit a20c7c8

File tree

10 files changed

+423
-11
lines changed

10 files changed

+423
-11
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
This is a dummy multimodal projector file for testing purposes.
2+
It contains sample content to simulate a real multimodal projector file.

pkg/distribution/builder/builder.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ func (b *Builder) WithContextSize(size uint64) *Builder {
4444
}
4545
}
4646

47+
// WithMultimodalProjector adds a Multimodal projector file to the artifact
48+
func (b *Builder) WithMultimodalProjector(path string) (*Builder, error) {
49+
mmprojLayer, err := partial.NewLayer(path, types.MediaTypeMultimodalProjector)
50+
if err != nil {
51+
return nil, fmt.Errorf("mmproj layer from %q: %w", path, err)
52+
}
53+
return &Builder{
54+
model: mutate.AppendLayers(b.model, mmprojLayer),
55+
}, nil
56+
}
57+
4758
// Target represents a build target
4859
type Target interface {
4960
Write(context.Context, types.ModelArtifact, io.Writer) error
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package builder_test
2+
3+
import (
4+
"context"
5+
"io"
6+
"path/filepath"
7+
"testing"
8+
9+
"github.com/docker/model-distribution/builder"
10+
"github.com/docker/model-distribution/types"
11+
)
12+
13+
func TestWithMultimodalProjector(t *testing.T) {
14+
// Create a builder from a GGUF file
15+
b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf"))
16+
if err != nil {
17+
t.Fatalf("Failed to create builder from GGUF: %v", err)
18+
}
19+
20+
// Add multimodal projector
21+
b2, err := b.WithMultimodalProjector(filepath.Join("..", "assets", "dummy.mmproj"))
22+
if err != nil {
23+
t.Fatalf("Failed to add multimodal projector: %v", err)
24+
}
25+
26+
// Build the model
27+
target := &fakeTarget{}
28+
if err := b2.Build(t.Context(), target, nil); err != nil {
29+
t.Fatalf("Failed to build model: %v", err)
30+
}
31+
32+
// Verify the model has the expected layers
33+
manifest, err := target.artifact.Manifest()
34+
if err != nil {
35+
t.Fatalf("Failed to get manifest: %v", err)
36+
}
37+
38+
// Should have 2 layers: GGUF + multimodal projector
39+
if len(manifest.Layers) != 2 {
40+
t.Fatalf("Expected 2 layers, got %d", len(manifest.Layers))
41+
}
42+
43+
// Check that one layer has the multimodal projector media type
44+
foundMMProjLayer := false
45+
for _, layer := range manifest.Layers {
46+
if layer.MediaType == types.MediaTypeMultimodalProjector {
47+
foundMMProjLayer = true
48+
break
49+
}
50+
}
51+
52+
if !foundMMProjLayer {
53+
t.Error("Expected to find a layer with multimodal projector media type")
54+
}
55+
56+
// Note: We can't directly test MMPROJPath() on ModelArtifact interface
57+
// but we can verify the layer was added with correct media type above
58+
}
59+
60+
func TestWithMultimodalProjectorInvalidPath(t *testing.T) {
61+
// Create a builder from a GGUF file
62+
b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf"))
63+
if err != nil {
64+
t.Fatalf("Failed to create builder from GGUF: %v", err)
65+
}
66+
67+
// Try to add multimodal projector with invalid path
68+
_, err = b.WithMultimodalProjector("nonexistent/path/to/mmproj")
69+
if err == nil {
70+
t.Error("Expected error when adding multimodal projector with invalid path")
71+
}
72+
}
73+
74+
func TestWithMultimodalProjectorChaining(t *testing.T) {
75+
// Create a builder from a GGUF file
76+
b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf"))
77+
if err != nil {
78+
t.Fatalf("Failed to create builder from GGUF: %v", err)
79+
}
80+
81+
// Chain multiple operations: license + multimodal projector + context size
82+
b, err = b.WithLicense(filepath.Join("..", "assets", "license.txt"))
83+
if err != nil {
84+
t.Fatalf("Failed to add license: %v", err)
85+
}
86+
87+
b, err = b.WithMultimodalProjector(filepath.Join("..", "assets", "dummy.mmproj"))
88+
if err != nil {
89+
t.Fatalf("Failed to add multimodal projector: %v", err)
90+
}
91+
92+
b = b.WithContextSize(4096)
93+
94+
// Build the model
95+
target := &fakeTarget{}
96+
if err := b.Build(t.Context(), target, nil); err != nil {
97+
t.Fatalf("Failed to build model: %v", err)
98+
}
99+
100+
// Verify the final model has all expected layers and properties
101+
manifest, err := target.artifact.Manifest()
102+
if err != nil {
103+
t.Fatalf("Failed to get manifest: %v", err)
104+
}
105+
106+
// Should have 3 layers: GGUF + license + multimodal projector
107+
if len(manifest.Layers) != 3 {
108+
t.Fatalf("Expected 3 layers, got %d", len(manifest.Layers))
109+
}
110+
111+
// Check media types - using string comparison since we can't use types.MediaType directly
112+
expectedMediaTypes := map[string]bool{
113+
string(types.MediaTypeGGUF): false,
114+
string(types.MediaTypeLicense): false,
115+
string(types.MediaTypeMultimodalProjector): false,
116+
}
117+
118+
for _, layer := range manifest.Layers {
119+
if _, exists := expectedMediaTypes[string(layer.MediaType)]; exists {
120+
expectedMediaTypes[string(layer.MediaType)] = true
121+
}
122+
}
123+
124+
for mediaType, found := range expectedMediaTypes {
125+
if !found {
126+
t.Errorf("Expected to find layer with media type %s", mediaType)
127+
}
128+
}
129+
130+
// Check context size
131+
config, err := target.artifact.Config()
132+
if err != nil {
133+
t.Fatalf("Failed to get config: %v", err)
134+
}
135+
136+
if config.ContextSize == nil || *config.ContextSize != 4096 {
137+
t.Errorf("Expected context size 4096, got %v", config.ContextSize)
138+
}
139+
140+
// Note: We can't directly test GGUFPath() and MMPROJPath() on ModelArtifact interface
141+
// but we can verify the layers were added with correct media types above
142+
}
143+
144+
var _ builder.Target = &fakeTarget{}
145+
146+
type fakeTarget struct {
147+
artifact types.ModelArtifact
148+
}
149+
150+
func (ft *fakeTarget) Write(ctx context.Context, artifact types.ModelArtifact, writer io.Writer) error {
151+
ft.artifact = artifact
152+
return nil
153+
}

pkg/distribution/internal/partial/partial.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
v1 "github.com/google/go-containerregistry/pkg/v1"
88
"github.com/google/go-containerregistry/pkg/v1/partial"
99
ggcr "github.com/google/go-containerregistry/pkg/v1/types"
10-
"github.com/pkg/errors"
1110

1211
"github.com/docker/model-distribution/types"
1312
)
@@ -67,22 +66,31 @@ type WithLayers interface {
6766
}
6867

6968
func GGUFPath(i WithLayers) (string, error) {
69+
return layerPathByMediaType(i, types.MediaTypeGGUF)
70+
}
71+
72+
func MMPROJPath(i WithLayers) (string, error) {
73+
return layerPathByMediaType(i, types.MediaTypeMultimodalProjector)
74+
}
75+
76+
// layerPathByMediaType is a generic helper function that finds a layer by media type and returns its path
77+
func layerPathByMediaType(i WithLayers, mediaType ggcr.MediaType) (string, error) {
7078
layers, err := i.Layers()
7179
if err != nil {
7280
return "", fmt.Errorf("get layers: %w", err)
7381
}
7482
for _, l := range layers {
7583
mt, err := l.MediaType()
76-
if err != nil || mt != types.MediaTypeGGUF {
84+
if err != nil || mt != mediaType {
7785
continue
7886
}
79-
ggufLayer, ok := l.(*Layer)
87+
layer, ok := l.(*Layer)
8088
if !ok {
81-
return "", errors.New("gguf Layer is not available locally")
89+
return "", fmt.Errorf("%s Layer is not available locally", mediaType)
8290
}
83-
return ggufLayer.Path, nil
91+
return layer.Path, nil
8492
}
85-
return "", errors.New("model does not contain a GGUF layer")
93+
return "", fmt.Errorf("model does not contain a %s layer", mediaType)
8694
}
8795

8896
func ManifestForLayers(i WithLayers) (*v1.Manifest, error) {
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package partial_test
2+
3+
import (
4+
"path/filepath"
5+
"testing"
6+
7+
"github.com/docker/model-distribution/internal/gguf"
8+
"github.com/docker/model-distribution/internal/mutate"
9+
"github.com/docker/model-distribution/internal/partial"
10+
"github.com/docker/model-distribution/types"
11+
)
12+
13+
func TestMMPROJPath(t *testing.T) {
14+
// Create a model from GGUF file
15+
mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf"))
16+
if err != nil {
17+
t.Fatalf("Failed to create model from GGUF: %v", err)
18+
}
19+
20+
// Add multimodal projector layer
21+
mmprojLayer, err := partial.NewLayer(filepath.Join("..", "..", "assets", "dummy.mmproj"), types.MediaTypeMultimodalProjector)
22+
if err != nil {
23+
t.Fatalf("Failed to create multimodal projector layer: %v", err)
24+
}
25+
26+
mdlWithMMProj := mutate.AppendLayers(mdl, mmprojLayer)
27+
28+
// Test MMPROJPath function
29+
mmprojPath, err := partial.MMPROJPath(mdlWithMMProj)
30+
if err != nil {
31+
t.Fatalf("Failed to get multimodal projector path: %v", err)
32+
}
33+
34+
expectedPath := filepath.Join("..", "..", "assets", "dummy.mmproj")
35+
if mmprojPath != expectedPath {
36+
t.Errorf("Expected multimodal projector path %s, got %s", expectedPath, mmprojPath)
37+
}
38+
}
39+
40+
func TestMMPROJPathNotFound(t *testing.T) {
41+
// Create a model from a GGUF file without a Multimodal projector
42+
mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf"))
43+
if err != nil {
44+
t.Fatalf("Failed to create model from GGUF: %v", err)
45+
}
46+
47+
// Test MMPROJPath function should return error
48+
_, err = partial.MMPROJPath(mdl)
49+
if err == nil {
50+
t.Error("Expected error when getting multimodal projector path from model without multimodal projector layer")
51+
}
52+
53+
expectedErrorMsg := "model does not contain a application/vnd.docker.ai.mmproj layer"
54+
if err.Error() != expectedErrorMsg {
55+
t.Errorf("Expected error message %q, got %q", expectedErrorMsg, err.Error())
56+
}
57+
}
58+
59+
func TestGGUFPath(t *testing.T) {
60+
// Create a model from GGUF file
61+
mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf"))
62+
if err != nil {
63+
t.Fatalf("Failed to create model from GGUF: %v", err)
64+
}
65+
66+
// Test GGUFPath function
67+
ggufPath, err := partial.GGUFPath(mdl)
68+
if err != nil {
69+
t.Fatalf("Failed to get GGUF path: %v", err)
70+
}
71+
72+
expectedPath := filepath.Join("..", "..", "assets", "dummy.gguf")
73+
if ggufPath != expectedPath {
74+
t.Errorf("Expected GGUF path %s, got %s", expectedPath, ggufPath)
75+
}
76+
}
77+
78+
func TestLayerPathByMediaType(t *testing.T) {
79+
// Create a model from GGUF file
80+
mdl, err := gguf.NewModel(filepath.Join("..", "..", "assets", "dummy.gguf"))
81+
if err != nil {
82+
t.Fatalf("Failed to create model from GGUF: %v", err)
83+
}
84+
85+
// Add license layer
86+
licenseLayer, err := partial.NewLayer(filepath.Join("..", "..", "assets", "license.txt"), types.MediaTypeLicense)
87+
if err != nil {
88+
t.Fatalf("Failed to create license layer: %v", err)
89+
}
90+
91+
// Add a Multimodal projector layer
92+
mmprojLayer, err := partial.NewLayer(filepath.Join("..", "..", "assets", "dummy.mmproj"), types.MediaTypeMultimodalProjector)
93+
if err != nil {
94+
t.Fatalf("Failed to create multimodal projector layer: %v", err)
95+
}
96+
97+
mdlWithLayers := mutate.AppendLayers(mdl, licenseLayer, mmprojLayer)
98+
99+
// Test that we can find each layer type
100+
ggufPath, err := partial.GGUFPath(mdlWithLayers)
101+
if err != nil {
102+
t.Fatalf("Failed to get GGUF path: %v", err)
103+
}
104+
if ggufPath != filepath.Join("..", "..", "assets", "dummy.gguf") {
105+
t.Errorf("Expected GGUF path to be: %s, got: %s", filepath.Join("..", "..", "assets", "dummy.gguf"), ggufPath)
106+
}
107+
108+
mmprojPath, err := partial.MMPROJPath(mdlWithLayers)
109+
if err != nil {
110+
t.Fatalf("Failed to get multimodal projector path: %v", err)
111+
}
112+
if mmprojPath != filepath.Join("..", "..", "assets", "dummy.mmproj") {
113+
t.Errorf("Expected multimodal projector path to be: %s, got: %s", filepath.Join("..", "..", "assets", "dummy.mmproj"), mmprojPath)
114+
}
115+
116+
}

pkg/distribution/internal/store/model.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ func (m *Model) GGUFPath() (string, error) {
114114
return mdpartial.GGUFPath(m)
115115
}
116116

117+
func (m *Model) MMPROJPath() (string, error) {
118+
return mdpartial.MMPROJPath(m)
119+
}
120+
117121
// Tags returns the tags held by the model.
func (m *Model) Tags() []string {
	return m.tags
}

0 commit comments

Comments
 (0)