11package main
22
33import (
4+ "archive/tar"
45 "context"
56 "flag"
67 "fmt"
8+ "io"
79 "os"
810 "path/filepath"
11+ "sort"
912 "strings"
1013
1114 "github.com/docker/model-runner/pkg/distribution/builder"
@@ -178,7 +181,12 @@ func cmdPackage(args []string) int {
178181 fs .StringVar (& chatTemplate , "chat-template" , "" , "Jinja chat template file" )
179182
180183 fs .Usage = func () {
181- fmt .Fprintf (os .Stderr , "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf>\n \n " )
184+ fmt .Fprintf (os .Stderr , "Usage: model-distribution-tool package [OPTIONS] <path-to-model-or-directory>\n \n " )
185+ fmt .Fprintf (os .Stderr , "Examples:\n " )
186+ fmt .Fprintf (os .Stderr , " # GGUF model:\n " )
187+ fmt .Fprintf (os .Stderr , " model-distribution-tool package model.gguf --tag registry/model:tag\n \n " )
188+ fmt .Fprintf (os .Stderr , " # Safetensors model:\n " )
189+ fmt .Fprintf (os .Stderr , " model-distribution-tool package ./qwen-model-dir --tag registry/model:tag\n \n " )
182190 fmt .Fprintf (os .Stderr , "Options:\n " )
183191 fs .PrintDefaults ()
184192 }
@@ -189,32 +197,62 @@ func cmdPackage(args []string) int {
189197 }
190198 args = fs .Args ()
191199
200+ // Get the source from positional argument
192201 if len (args ) < 1 {
193- fmt .Fprintf (os .Stderr , "Error: missing arguments \n " )
202+ fmt .Fprintf (os .Stderr , "Error: no model file or directory specified \n " )
194203 fs .Usage ()
195204 return 1
196205 }
197- if file == "" && tag == "" {
198- fmt .Fprintf (os .Stderr , "Error: one of --file or --tag is required\n " )
199- fs .Usage ()
206+
207+ source := args [0 ]
208+ var isSafetensors bool
209+ var configArchive string // For safetensors config
210+ var safetensorsPaths []string // For safetensors model files
211+
212+ // Check if source exists
213+ sourceInfo , err := os .Stat (source )
214+ if os .IsNotExist (err ) {
215+ fmt .Fprintf (os .Stderr , "Error: source does not exist: %s\n " , source )
200216 return 1
201217 }
202218
203- source := args [0 ]
204- ctx := context .Background ()
219+ // Handle directory-based packaging (for safetensors models)
220+ if sourceInfo .IsDir () {
221+ fmt .Printf ("Detected directory, scanning for safetensors model...\n " )
222+ var err error
223+ safetensorsPaths , configArchive , err = packageFromDirectory (source )
224+ if err != nil {
225+ fmt .Fprintf (os .Stderr , "Error scanning directory: %v\n " , err )
226+ return 1
227+ }
205228
206- // Check if source file exists
207- if _ , err := os .Stat (source ); os .IsNotExist (err ) {
208- fmt .Fprintf (os .Stderr , "Error: source file does not exist: %s\n " , source )
209- return 1
229+ isSafetensors = true
230+ fmt .Printf ("Found %d safetensors file(s)\n " , len (safetensorsPaths ))
231+
232+ // Clean up temp config archive when done
233+ if configArchive != "" {
234+ defer os .Remove (configArchive )
235+ fmt .Printf ("Created temporary config archive from directory\n " )
236+ }
237+ } else {
238+ // Handle single file (GGUF model)
239+ if strings .HasSuffix (strings .ToLower (source ), ".gguf" ) {
240+ isSafetensors = false
241+ fmt .Println ("Detected GGUF model file" )
242+ } else {
243+ fmt .Fprintf (os .Stderr , "Warning: could not determine model type for: %s\n " , source )
244+ fmt .Fprintf (os .Stderr , "Assuming GGUF format.\n " )
245+ }
210246 }
211247
212- // Check if source file is a GGUF file
213- if ! strings . HasSuffix ( strings . ToLower ( source ) , ".gguf" ) {
214- fmt . Fprintf ( os . Stderr , "Warning: source file does not have .gguf extension: %s \n " , source )
215- fmt . Fprintf ( os . Stderr , "Continuing anyway, but this may cause issues. \n " )
248+ if file == "" && tag == "" {
249+ fmt . Fprintf ( os . Stderr , "Error: one of --file or --tag is required \n " )
250+ fs . Usage ( )
251+ return 1
216252 }
217253
254+ ctx := context .Background ()
255+
218256 // Prepare registry client options
219257 registryClientOpts := []registry.ClientOption {
220258 registry .WithUserAgent ("model-distribution-tool/" + version ),
@@ -230,31 +268,49 @@ func cmdPackage(args []string) int {
230268 // Create registry client once with all options
231269 registryClient := registry .NewClient (registryClientOpts ... )
232270
233- var (
234- target builder.Target
235- err error
236- )
271+ var target builder.Target
237272 if file != "" {
238273 target = tarball .NewFileTarget (file )
239274 } else {
275+ var err error
240276 target , err = registryClient .NewTarget (tag )
241277 if err != nil {
242278 fmt .Fprintf (os .Stderr , "Create packaging target: %v\n " , err )
243279 return 1
244280 }
245281 }
246282
247- // Create image with layer
248- builder , err := builder .FromGGUF (source )
249- if err != nil {
250- fmt .Fprintf (os .Stderr , "Error creating model from gguf: %v\n " , err )
251- return 1
283+ // Create builder based on model type
284+ var b * builder.Builder
285+ if isSafetensors {
286+ fmt .Println ("Creating safetensors model" )
287+ b , err = builder .FromSafetensors (safetensorsPaths )
288+ if err != nil {
289+ fmt .Fprintf (os .Stderr , "Error creating model from safetensors: %v\n " , err )
290+ return 1
291+ }
292+
293+ // Add config archive if provided
294+ if configArchive != "" {
295+ fmt .Printf ("Adding config archive: %s\n " , configArchive )
296+ b , err = b .WithConfigArchive (configArchive )
297+ if err != nil {
298+ fmt .Fprintf (os .Stderr , "Error adding config archive: %v\n " , err )
299+ return 1
300+ }
301+ }
302+ } else {
303+ b , err = builder .FromGGUF (source )
304+ if err != nil {
305+ fmt .Fprintf (os .Stderr , "Error creating model from gguf: %v\n " , err )
306+ return 1
307+ }
252308 }
253309
254310 // Add all license files as layers
255311 for _ , path := range licensePaths {
256312 fmt .Println ("Adding license file:" , path )
257- builder , err = builder .WithLicense (path )
313+ b , err = b .WithLicense (path )
258314 if err != nil {
259315 fmt .Fprintf (os .Stderr , "Error adding license layer for %s: %v\n " , path , err )
260316 return 1
@@ -263,12 +319,12 @@ func cmdPackage(args []string) int {
263319
264320 if contextSize > 0 {
265321 fmt .Println ("Setting context size:" , contextSize )
266- builder = builder .WithContextSize (contextSize )
322+ b = b .WithContextSize (contextSize )
267323 }
268324
269325 if mmproj != "" {
270326 fmt .Println ("Adding multimodal projector file:" , mmproj )
271- builder , err = builder .WithMultimodalProjector (mmproj )
327+ b , err = b .WithMultimodalProjector (mmproj )
272328 if err != nil {
273329 fmt .Fprintf (os .Stderr , "Error adding multimodal projector layer for %s: %v\n " , mmproj , err )
274330 return 1
@@ -277,15 +333,15 @@ func cmdPackage(args []string) int {
277333
278334 if chatTemplate != "" {
279335 fmt .Println ("Adding chat template file:" , chatTemplate )
280- builder , err = builder .WithChatTemplateFile (chatTemplate )
336+ b , err = b .WithChatTemplateFile (chatTemplate )
281337 if err != nil {
282338 fmt .Fprintf (os .Stderr , "Error adding chat template layer for %s: %v\n " , chatTemplate , err )
283339 return 1
284340 }
285341 }
286342
287343 // Push the image
288- if err := builder .Build (ctx , target , os .Stdout ); err != nil {
344+ if err := b .Build (ctx , target , os .Stdout ); err != nil {
289345 fmt .Fprintf (os .Stderr , "Error writing model to registry: %v\n " , err )
290346 return 1
291347 }
@@ -525,3 +581,132 @@ func cmdBundle(client *distribution.Client, args []string) int {
525581 fmt .Fprint (os .Stdout , bundle .RootDir ())
526582 return 0
527583}
584+
585+ // packageFromDirectory scans a directory for safetensors files and config files,
586+ // creating a temporary tar archive of the config files
587+ func packageFromDirectory (dirPath string ) (safetensorsPaths []string , tempConfigArchive string , err error ) {
588+ // Read directory contents (only top level, no subdirectories)
589+ entries , err := os .ReadDir (dirPath )
590+ if err != nil {
591+ return nil , "" , fmt .Errorf ("read directory: %w" , err )
592+ }
593+
594+ var configFiles []string
595+
596+ for _ , entry := range entries {
597+ if entry .IsDir () {
598+ continue // Skip subdirectories
599+ }
600+
601+ name := entry .Name ()
602+ fullPath := filepath .Join (dirPath , name )
603+
604+ // Collect safetensors files
605+ if strings .HasSuffix (strings .ToLower (name ), ".safetensors" ) {
606+ safetensorsPaths = append (safetensorsPaths , fullPath )
607+ }
608+
609+ // Collect config files: *.json, merges.txt
610+ if strings .HasSuffix (strings .ToLower (name ), ".json" ) ||
611+ name == "merges.txt" {
612+ configFiles = append (configFiles , fullPath )
613+ }
614+ }
615+
616+ if len (safetensorsPaths ) == 0 {
617+ return nil , "" , fmt .Errorf ("no safetensors files found in directory: %s" , dirPath )
618+ }
619+
620+ // Sort to ensure reproducible artifacts
621+ sort .Strings (safetensorsPaths )
622+
623+ // Create temporary tar archive with config files if any exist
624+ if len (configFiles ) > 0 {
625+ // Sort config files for reproducible tar archive
626+ sort .Strings (configFiles )
627+
628+ tempConfigArchive , err = createTempConfigArchive (configFiles )
629+ if err != nil {
630+ return nil , "" , fmt .Errorf ("create config archive: %w" , err )
631+ }
632+ }
633+
634+ return safetensorsPaths , tempConfigArchive , nil
635+ }
636+
637+ // createTempConfigArchive creates a temporary tar archive containing the specified config files
638+ func createTempConfigArchive (configFiles []string ) (string , error ) {
639+ // Create temp file
640+ tmpFile , err := os .CreateTemp ("" , "vllm-config-*.tar" )
641+ if err != nil {
642+ return "" , fmt .Errorf ("create temp file: %w" , err )
643+ }
644+ tmpPath := tmpFile .Name ()
645+
646+ // Create tar writer
647+ tw := tar .NewWriter (tmpFile )
648+
649+ // Add each config file to tar (preserving just filename, not full path)
650+ for _ , filePath := range configFiles {
651+ // Open the file
652+ file , err := os .Open (filePath )
653+ if err != nil {
654+ tw .Close ()
655+ tmpFile .Close ()
656+ os .Remove (tmpPath )
657+ return "" , fmt .Errorf ("open config file %s: %w" , filePath , err )
658+ }
659+
660+ // Get file info for tar header
661+ fileInfo , err := file .Stat ()
662+ if err != nil {
663+ file .Close ()
664+ tw .Close ()
665+ tmpFile .Close ()
666+ os .Remove (tmpPath )
667+ return "" , fmt .Errorf ("stat config file %s: %w" , filePath , err )
668+ }
669+
670+ // Create tar header (use only basename, not full path)
671+ header := & tar.Header {
672+ Name : filepath .Base (filePath ),
673+ Size : fileInfo .Size (),
674+ Mode : int64 (fileInfo .Mode ()),
675+ ModTime : fileInfo .ModTime (),
676+ }
677+
678+ // Write header
679+ if err := tw .WriteHeader (header ); err != nil {
680+ file .Close ()
681+ tw .Close ()
682+ tmpFile .Close ()
683+ os .Remove (tmpPath )
684+ return "" , fmt .Errorf ("write tar header for %s: %w" , filePath , err )
685+ }
686+
687+ // Copy file contents
688+ if _ , err := io .Copy (tw , file ); err != nil {
689+ file .Close ()
690+ tw .Close ()
691+ tmpFile .Close ()
692+ os .Remove (tmpPath )
693+ return "" , fmt .Errorf ("write tar content for %s: %w" , filePath , err )
694+ }
695+
696+ file .Close ()
697+ }
698+
699+ // Close tar writer and file
700+ if err := tw .Close (); err != nil {
701+ tmpFile .Close ()
702+ os .Remove (tmpPath )
703+ return "" , fmt .Errorf ("close tar writer: %w" , err )
704+ }
705+
706+ if err := tmpFile .Close (); err != nil {
707+ os .Remove (tmpPath )
708+ return "" , fmt .Errorf ("close temp file: %w" , err )
709+ }
710+
711+ return tmpPath , nil
712+ }
0 commit comments