Skip to content

Commit 57b9bf4

Browse files
ekcaseydoringeman
andcommitted
Load model from archive (docker#113)
* Adds support for packaging/loading models to/from TAR archive Signed-off-by: Emily Casey <emily.casey@docker.com> * Check for blobs when writing manifest Signed-off-by: Emily Casey <emily.casey@docker.com> * LoadModel: don't crash on cancellation Signed-off-by: Dorin Geman <dorin.geman@docker.com> --------- Signed-off-by: Emily Casey <emily.casey@docker.com> Signed-off-by: Dorin Geman <dorin.geman@docker.com> Co-authored-by: Dorin Geman <dorin.geman@docker.com>
1 parent 4cea686 commit 57b9bf4

File tree

1 file changed

+84
-12
lines changed

1 file changed

+84
-12
lines changed

cmd/mdltool/main.go

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/docker/model-distribution/builder"
1212
"github.com/docker/model-distribution/distribution"
1313
"github.com/docker/model-distribution/registry"
14+
"github.com/docker/model-distribution/tarball"
1415
)
1516

1617
// stringSliceFlag is a flag that can be specified multiple times to collect multiple string values
@@ -103,6 +104,8 @@ func main() {
103104
exitCode = cmdRm(client, args)
104105
case "tag":
105106
exitCode = cmdTag(client, args)
107+
case "load":
108+
exitCode = cmdLoad(client, args)
106109
default:
107110
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
108111
printUsage()
@@ -154,15 +157,23 @@ func cmdPull(client *distribution.Client, args []string) int {
154157

155158
func cmdPackage(args []string) int {
156159
fs := flag.NewFlagSet("package", flag.ExitOnError)
157-
var licensePaths stringSliceFlag
158-
var contextSize uint64
159-
var mmproj string
160+
var (
161+
licensePaths stringSliceFlag
162+
contextSize uint64
163+
file string
164+
tag string
165+
mmproj string
166+
)
160167

161168
fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)")
162169
fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens")
163170
fs.StringVar(&mmproj, "mmproj", "", "Path to Multimodal Projector file")
171+
fs.StringVar(&file, "file", "", "Write archived model to the given file")
172+
fs.StringVar(&tag, "tag", "", "Push model to the given registry tag")
173+
fs.Parse(args)
174+
164175
fs.Usage = func() {
165-
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf> <reference>\n\n")
176+
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf>\n\n")
166177
fmt.Fprintf(os.Stderr, "Options:\n")
167178
fs.PrintDefaults()
168179
}
@@ -173,14 +184,18 @@ func cmdPackage(args []string) int {
173184
}
174185
args = fs.Args()
175186

176-
if len(args) < 2 {
187+
if len(args) < 1 {
177188
fmt.Fprintf(os.Stderr, "Error: missing arguments\n")
178189
fs.Usage()
179190
return 1
180191
}
192+
if file == "" && tag == "" {
193+
fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n")
194+
fs.Usage()
195+
return 1
196+
}
181197

182198
source := args[0]
183-
reference := args[1]
184199
ctx := context.Background()
185200

186201
// Check if source file exists
@@ -210,11 +225,18 @@ func cmdPackage(args []string) int {
210225
// Create registry client once with all options
211226
registryClient := registry.NewClient(registryClientOpts...)
212227

213-
// Parse the reference
214-
target, err := registryClient.NewTarget(reference)
215-
if err != nil {
216-
fmt.Fprintf(os.Stderr, "Error parsing reference: %v\n", err)
217-
return 1
228+
var (
229+
target builder.Target
230+
err error
231+
)
232+
if file != "" {
233+
target = tarball.NewFileTarget(file)
234+
} else {
235+
target, err = registryClient.NewTarget(tag)
236+
if err != nil {
237+
fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err)
238+
return 1
239+
}
218240
}
219241

220242
// Create image with layer
@@ -250,9 +272,59 @@ func cmdPackage(args []string) int {
250272

251273
// Push the image
252274
if err := builder.Build(ctx, target, os.Stdout); err != nil {
253-
fmt.Fprintf(os.Stderr, "Error writing model %q to registry: %v\n", reference, err)
275+
fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err)
254276
return 1
255277
}
278+
if tag != "" {
279+
fmt.Printf("Successfully packaged and pushed model: %s\n", tag)
280+
} else {
281+
fmt.Printf("Successfully packaged model to file: %s\n", file)
282+
}
283+
return 0
284+
}
285+
286+
func cmdLoad(client *distribution.Client, args []string) int {
287+
fs := flag.NewFlagSet("load", flag.ExitOnError)
288+
var (
289+
tag string
290+
)
291+
fs.StringVar(&tag, "tag", "", "Apply tag to the loaded model")
292+
fs.Usage = func() {
293+
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool load [OPTIONS] <path-to-archive>\n\n")
294+
fmt.Fprintf(os.Stderr, "Options:\n")
295+
fs.PrintDefaults()
296+
}
297+
298+
if err := fs.Parse(args); err != nil {
299+
fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err)
300+
return 1
301+
}
302+
args = fs.Args()
303+
304+
if len(args) < 1 {
305+
fmt.Fprintf(os.Stderr, "Error: missing required argument\n")
306+
fs.Usage()
307+
return 1
308+
}
309+
path := args[0]
310+
311+
f, err := os.Open(path)
312+
if err != nil {
313+
fmt.Fprintf(os.Stderr, "Error opening model file: %v\n", err)
314+
return 1
315+
}
316+
defer f.Close()
317+
318+
id, err := client.LoadModel(f, os.Stdout)
319+
if err != nil {
320+
fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
321+
return 1
322+
}
323+
fmt.Fprintln(os.Stdout, "Loaded model:", id)
324+
if err := client.Tag(id, tag); err != nil {
325+
fmt.Fprintf(os.Stderr, "Error tagging model: %v\n", err)
326+
}
327+
fmt.Fprintln(os.Stdout, "Tagged model:", tag)
256328
return 0
257329
}
258330

0 commit comments

Comments
 (0)