55 "fmt"
66 "io"
77 "os"
8+ "path/filepath"
89 "sort"
910 "sync"
1011 "sync/atomic"
@@ -27,6 +28,11 @@ import (
2728 otelCodes "go.opentelemetry.io/otel/codes"
2829)
2930
31+ const (
32+ safetensorFilePath = "model.safetensors.index.json"
33+ safetensorFileExt = ".safetensors"
34+ )
35+
3036type PullHook interface {
3137 BeforePullLayer (desc ocispec.Descriptor , manifest ocispec.Manifest )
3238 AfterPullLayer (desc ocispec.Descriptor , err error )
@@ -50,7 +56,7 @@ func NewHook(ctx context.Context, progressCb func(progress status.Progress)) *Ho
5056}
5157
5258type Puller interface {
53- Pull (ctx context.Context , reference , targetDir string ) error
59+ Pull (ctx context.Context , reference , targetDir string , excludeModelWeights bool ) error
5460}
5561
5662var NewPuller = func (ctx context.Context , pullCfg * config.PullConfig , hook * Hook , diskQuotaChecker * DiskQuotaChecker ) Puller {
@@ -201,7 +207,50 @@ func (h *Hook) GetProgress() status.Progress {
201207 return h .getProgress ()
202208}
203209
204- func (p * puller ) Pull (ctx context.Context , reference , targetDir string ) error {
210+ func isSafetensorFile (layer backend.InspectedModelArtifactLayer ) bool {
211+ // Check file path
212+ if layer .Filepath == safetensorFilePath {
213+ return true
214+ }
215+ // Compatibility for old model artifact format
216+ if filepath .Ext (layer .Filepath ) == safetensorFileExt {
217+ return true
218+ }
219+ return false
220+ }
221+
222+ func isWeightLayer (layer backend.InspectedModelArtifactLayer ) bool {
223+ // Check media type
224+ if layer .MediaType == modelspec .MediaTypeModelWeightRaw ||
225+ layer .MediaType == modelspec .MediaTypeModelWeight ||
226+ layer .MediaType == modelspec .MediaTypeModelWeightGzip ||
227+ layer .MediaType == modelspec .MediaTypeModelWeightZstd {
228+ return true
229+ }
230+ if isSafetensorFile (layer ) {
231+ return true
232+ }
233+ return false
234+ }
235+
236+ func getPatternsWithoutWeights (ctx context.Context , layers []backend.InspectedModelArtifactLayer ) []string {
237+ paths := []string {}
238+ for idx := range layers {
239+ layer := layers [idx ]
240+ if layer .Filepath == "" {
241+ logger .Logger ().WithContext (ctx ).Warnf (
242+ "layer %s has no file path, skip" , layer .Digest ,
243+ )
244+ continue
245+ }
246+ if ! isWeightLayer (layer ) {
247+ paths = append (paths , layer .Filepath )
248+ }
249+ }
250+ return paths
251+ }
252+
253+ func (p * puller ) Pull (ctx context.Context , reference , targetDir string , excludeModelWeights bool ) error {
205254 keyChain , err := auth .GetKeyChainByRef (reference )
206255 if err != nil {
207256 return errors .Wrapf (err , "get auth for model: %s" , reference )
@@ -223,28 +272,55 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string) error {
223272 return errors .Wrapf (err , "create model dir: %s" , targetDir )
224273 }
225274
226- if p .pullCfg .Concurrency < 1 {
227- p .pullCfg .Concurrency = 5
228- }
275+ go p .checkLongPulling (ctx )
229276
230- pullConfig := modctlConfig .NewPull ()
231- pullConfig .Concurrency = int (p .pullCfg .Concurrency )
232- pullConfig .PlainHTTP = keyChain .ServerScheme == "http"
233- pullConfig .Proxy = p .pullCfg .ProxyURL
234- pullConfig .DragonflyEndpoint = p .pullCfg .DragonflyEndpoint
235- pullConfig .Insecure = true
236- pullConfig .ExtractDir = targetDir
237- pullConfig .ExtractFromRemote = true
238- pullConfig .Hooks = p .hook
239- pullConfig .ProgressWriter = io .Discard
240- pullConfig .DisableProgress = true
277+ plainHTTP := keyChain .ServerScheme == "http"
278+
279+ if ! excludeModelWeights {
280+ pullConfig := modctlConfig .NewPull ()
281+ pullConfig .Concurrency = int (p .pullCfg .Concurrency )
282+ pullConfig .PlainHTTP = plainHTTP
283+ pullConfig .Proxy = p .pullCfg .ProxyURL
284+ pullConfig .DragonflyEndpoint = p .pullCfg .DragonflyEndpoint
285+ pullConfig .Insecure = true
286+ pullConfig .ExtractDir = targetDir
287+ pullConfig .ExtractFromRemote = true
288+ pullConfig .Hooks = p .hook
289+ pullConfig .ProgressWriter = io .Discard
290+ pullConfig .DisableProgress = true
291+
292+ if err := b .Pull (ctx , reference , pullConfig ); err != nil {
293+ logger .WithContext (ctx ).WithError (err ).Errorf ("failed to pull model image: %s" , reference )
294+ return errors .Wrap (err , "pull model image" )
295+ }
241296
242- go p .checkLongPulling (ctx )
297+ return nil
298+ }
243299
244- if err := b .Pull (ctx , reference , pullConfig ); err != nil {
245- logger .WithContext (ctx ).WithError (err ).Errorf ("failed to pull model image: %s" , reference )
246- return errors .Wrap (err , "pull model image" )
300+ start := time .Now ()
301+ result , err := b .Inspect (ctx , reference , & modctlConfig.Inspect {
302+ Remote : true ,
303+ Insecure : true ,
304+ PlainHTTP : plainHTTP ,
305+ })
306+ if err != nil {
307+ return errors .Wrap (err , "inspect model" )
247308 }
309+ logger .WithContext (ctx ).Infof ("inspected model %s, duration: %s" , reference , time .Since (start ))
310+ modelArtifact , ok := result .(* backend.InspectedModelArtifact )
311+ if ! ok {
312+ return errors .Errorf ("invalid inspected result: %s" , reference )
313+ }
314+
315+ patterns := getPatternsWithoutWeights (ctx , modelArtifact .Layers )
316+
317+ fetchConfig := modctlConfig .NewFetch ()
318+ fetchConfig .Concurrency = int (p .pullCfg .Concurrency )
319+ fetchConfig .PlainHTTP = plainHTTP
320+ fetchConfig .Proxy = p .pullCfg .ProxyURL
321+ fetchConfig .Insecure = true
322+ fetchConfig .Output = targetDir
323+ fetchConfig .Patterns = patterns
248324
249325 return nil
250326}
0 commit comments