@@ -2,30 +2,18 @@ package service
22
33import (
44 "context"
5- "fmt"
65 "io"
76 "os"
8- "sort"
97 "strings"
10- "sync"
11- "sync/atomic"
12- "time"
138
14- "github.com/dustin/go-humanize"
159 "github.com/modelpack/modctl/pkg/backend"
1610 modctlConfig "github.com/modelpack/modctl/pkg/config"
1711 "github.com/modelpack/model-csi-driver/pkg/config"
1812 "github.com/modelpack/model-csi-driver/pkg/config/auth"
1913 "github.com/modelpack/model-csi-driver/pkg/logger"
20- "github.com/modelpack/model-csi-driver/pkg/metrics"
2114 "github.com/modelpack/model-csi-driver/pkg/status"
22- "github.com/modelpack/model-csi-driver/pkg/tracing"
23- modelspec "github.com/modelpack/model-spec/specs-go/v1"
24- "github.com/opencontainers/go-digest"
2515 ocispec "github.com/opencontainers/image-spec/specs-go/v1"
2616 "github.com/pkg/errors"
27- "go.opentelemetry.io/otel/attribute"
28- otelCodes "go.opentelemetry.io/otel/codes"
2917)
3018
3119const (
@@ -37,28 +25,11 @@ type PullHook interface {
3725 AfterPullLayer (desc ocispec.Descriptor , err error )
3826}
3927
40- type Hook struct {
41- ctx context.Context
42- mutex sync.Mutex
43- manifest * ocispec.Manifest
44- pulled atomic.Uint32
45- progress map [digest.Digest ]* status.ProgressItem
46- progressCb func (progress status.Progress )
47- }
48-
49- func NewHook (ctx context.Context , progressCb func (progress status.Progress )) * Hook {
50- return & Hook {
51- ctx : ctx ,
52- progress : make (map [digest.Digest ]* status.ProgressItem ),
53- progressCb : progressCb ,
54- }
55- }
56-
5728type Puller interface {
5829 Pull (ctx context.Context , reference , targetDir string , excludeModelWeights bool ) error
5930}
6031
61- var NewPuller = func (ctx context.Context , pullCfg * config.PullConfig , hook * Hook , diskQuotaChecker * DiskQuotaChecker ) Puller {
32+ var NewPuller = func (ctx context.Context , pullCfg * config.PullConfig , hook * status. Hook , diskQuotaChecker * DiskQuotaChecker ) Puller {
6233 return & puller {
6334 pullCfg : pullCfg ,
6435 hook : hook ,
@@ -68,144 +39,10 @@ var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *Hook
6839
6940type puller struct {
7041 pullCfg * config.PullConfig
71- hook * Hook
42+ hook * status. Hook
7243 diskQuotaChecker * DiskQuotaChecker
7344}
7445
75- func (h * Hook ) getProgressDesc () string {
76- finished := h .pulled .Load ()
77- if h .manifest == nil {
78- return fmt .Sprintf ("%d/unknown" , finished )
79- }
80-
81- total := len (h .manifest .Layers )
82-
83- return fmt .Sprintf ("%d/%d" , finished , total )
84- }
85-
86- func (h * Hook ) BeforePullLayer (desc ocispec.Descriptor , manifest ocispec.Manifest ) {
87- h .mutex .Lock ()
88- defer h .mutex .Unlock ()
89-
90- filePath := ""
91- if desc .Annotations != nil && desc .Annotations [modelspec .AnnotationFilepath ] != "" {
92- filePath = fmt .Sprintf ("/%s" , desc .Annotations [modelspec .AnnotationFilepath ])
93- }
94-
95- _ , span := tracing .Tracer .Start (h .ctx , "PullLayer" )
96- span .SetAttributes (attribute .String ("digest" , desc .Digest .String ()))
97- span .SetAttributes (attribute .String ("media_type" , desc .MediaType ))
98- span .SetAttributes (attribute .String ("file_path" , filePath ))
99- span .SetAttributes (attribute .Int64 ("size" , desc .Size ))
100-
101- h .manifest = & manifest
102- h .progress [desc .Digest ] = & status.ProgressItem {
103- Digest : desc .Digest ,
104- Path : filePath ,
105- Size : desc .Size ,
106- StartedAt : time .Now (),
107- FinishedAt : nil ,
108- Error : nil ,
109- Span : span ,
110- }
111-
112- h .progressCb (h .getProgress ())
113- }
114-
115- func (h * Hook ) AfterPullLayer (desc ocispec.Descriptor , err error ) {
116- h .mutex .Lock ()
117- defer h .mutex .Unlock ()
118-
119- progress := h .progress [desc .Digest ]
120- if progress == nil {
121- return
122- }
123-
124- metrics .NodePullOpObserve ("pull_layer" , progress .Size , progress .StartedAt , err )
125-
126- var finishedAt * time.Time
127- if err != nil {
128- logger .WithContext (h .ctx ).WithError (err ).Errorf ("failed to pull layer: %s%s (%s)" , progress .Digest , progress .Path , h .getProgressDesc ())
129- } else {
130- now := time .Now ()
131- finishedAt = & now
132- h .pulled .Add (1 )
133- duration := time .Since (progress .StartedAt )
134- logger .WithContext (h .ctx ).Infof (
135- "pulled layer: %s %s %s %s (%s) %s" ,
136- desc .MediaType , progress .Digest , progress .Path , humanize .Bytes (uint64 (progress .Size )), h .getProgressDesc (), duration ,
137- )
138- }
139-
140- progress .FinishedAt = finishedAt
141- progress .Error = err
142-
143- if err != nil {
144- progress .Span .SetStatus (otelCodes .Error , "failed to pull layer" )
145- progress .Span .RecordError (err )
146- }
147- progress .Span .End ()
148-
149- h .progressCb (h .getProgress ())
150- }
151-
152- func (p * puller ) checkLongPulling (ctx context.Context ) {
153- ticker := time .NewTicker (30 * time .Second )
154- defer ticker .Stop ()
155-
156- recorded := map [digest.Digest ]bool {}
157-
158- for {
159- select {
160- case <- ticker .C :
161- p .hook .mutex .Lock ()
162- for _ , progress := range p .hook .progress {
163- if progress .FinishedAt == nil &&
164- p .pullCfg .PullLayerTimeoutInSeconds > 0 &&
165- time .Since (progress .StartedAt ) > time .Duration (p .pullCfg .PullLayerTimeoutInSeconds )* time .Second &&
166- ! recorded [progress .Digest ] {
167- logger .WithContext (ctx ).Warnf ("pulling layer %s is taking too long: %s" , progress .Digest , time .Since (progress .StartedAt ))
168- metrics .NodePullLayerTooLong .Inc ()
169- recorded [progress .Digest ] = true
170- }
171- }
172- p .hook .mutex .Unlock ()
173- case <- ctx .Done ():
174- return
175- }
176- }
177- }
178-
179- func (h * Hook ) getProgress () status.Progress {
180- items := []status.ProgressItem {}
181- for _ , item := range h .progress {
182- items = append (items , * item )
183- }
184-
185- sort .Slice (items , func (i , j int ) bool {
186- if items [i ].StartedAt .Equal (items [j ].StartedAt ) {
187- return items [i ].Digest < items [j ].Digest
188- }
189- return items [i ].StartedAt .Before (items [j ].StartedAt )
190- })
191-
192- total := 0
193- if h .manifest != nil {
194- total = len (h .manifest .Layers )
195- }
196- return status.Progress {
197- Total : total ,
198- Items : items ,
199- }
200- }
201-
202- func (h * Hook ) GetProgress () status.Progress {
203- h .mutex .Lock ()
204- defer h .mutex .Unlock ()
205-
206- return h .getProgress ()
207- }
208-
20946func (p * puller ) Pull (ctx context.Context , reference , targetDir string , excludeModelWeights bool ) error {
21047 keyChain , err := auth .GetKeyChainByRef (reference )
21148 if err != nil {
@@ -231,8 +68,6 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeM
23168 }
23269
23370 if ! excludeModelWeights {
234- go p .checkLongPulling (ctx )
235-
23671 pullConfig := modctlConfig .NewPull ()
23772 pullConfig .Concurrency = int (p .pullCfg .Concurrency )
23873 pullConfig .PlainHTTP = plainHTTP
0 commit comments