Skip to content

Commit 0658ae6

Browse files
authored
hook: refactor to avoid write status to disk (#12)
The hook currently executes during each layer pull, updating progress info to the status file, which is a disk write operation that may impact model pull concurrency on lower-performance disks. This PR refactors the hook implementation to avoid writing per-layer progress to the disk-based status file, instead keeping it cached in memory, users can still get the model pull progress via the API as before. Signed-off-by: imeoer <[email protected]>
1 parent 4a710cf commit 0658ae6

File tree

6 files changed

+216
-186
lines changed

6 files changed

+216
-186
lines changed

pkg/server/server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ const (
2929
type mockPuller struct {
3030
pullCfg *config.PullConfig
3131
duration time.Duration
32-
hook *service.Hook
32+
hook *status.Hook
3333
}
3434

3535
func (puller *mockPuller) Pull(
@@ -560,7 +560,7 @@ func TestServer(t *testing.T) {
560560
cfg.Get().PullConfig.ProxyURL = ""
561561
service.CacheSacnInterval = 1 * time.Second
562562

563-
service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *service.Hook, diskQuotaChecker *service.DiskQuotaChecker) service.Puller {
563+
service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker) service.Puller {
564564
return &mockPuller{
565565
pullCfg: pullCfg,
566566
duration: time.Second * 2,

pkg/service/controller_local.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ func (s *Service) localListVolumes(
190190
logger.WithContext(ctx).WithError(err).Errorf("failed to get volume status")
191191
return nil, status.Error(codes.Internal, err.Error())
192192
}
193-
progress, err := modelStatus.Progress.String()
193+
progress := s.worker.sm.HookManager.GetProgress(statusPath)
194+
progressStr, err := progress.String()
194195
if err != nil {
195196
logger.WithContext(ctx).WithError(err).Errorf("failed to marshal progress")
196197
return nil, status.Error(codes.Internal, err.Error())
@@ -201,7 +202,7 @@ func (s *Service) localListVolumes(
201202
VolumeContext: map[string]string{
202203
s.cfg.Get().ParameterKeyReference(): modelStatus.Reference,
203204
s.cfg.Get().ParameterKeyStatusState(): modelStatus.State,
204-
s.cfg.Get().ParameterKeyStatusProgress(): progress,
205+
s.cfg.Get().ParameterKeyStatusProgress(): progressStr,
205206
},
206207
},
207208
}, nil

pkg/service/puller.go

Lines changed: 2 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,18 @@ package service
22

33
import (
44
"context"
5-
"fmt"
65
"io"
76
"os"
8-
"sort"
97
"strings"
10-
"sync"
11-
"sync/atomic"
12-
"time"
138

14-
"github.com/dustin/go-humanize"
159
"github.com/modelpack/modctl/pkg/backend"
1610
modctlConfig "github.com/modelpack/modctl/pkg/config"
1711
"github.com/modelpack/model-csi-driver/pkg/config"
1812
"github.com/modelpack/model-csi-driver/pkg/config/auth"
1913
"github.com/modelpack/model-csi-driver/pkg/logger"
20-
"github.com/modelpack/model-csi-driver/pkg/metrics"
2114
"github.com/modelpack/model-csi-driver/pkg/status"
22-
"github.com/modelpack/model-csi-driver/pkg/tracing"
23-
modelspec "github.com/modelpack/model-spec/specs-go/v1"
24-
"github.com/opencontainers/go-digest"
2515
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
2616
"github.com/pkg/errors"
27-
"go.opentelemetry.io/otel/attribute"
28-
otelCodes "go.opentelemetry.io/otel/codes"
2917
)
3018

3119
const (
@@ -37,28 +25,11 @@ type PullHook interface {
3725
AfterPullLayer(desc ocispec.Descriptor, err error)
3826
}
3927

40-
type Hook struct {
41-
ctx context.Context
42-
mutex sync.Mutex
43-
manifest *ocispec.Manifest
44-
pulled atomic.Uint32
45-
progress map[digest.Digest]*status.ProgressItem
46-
progressCb func(progress status.Progress)
47-
}
48-
49-
func NewHook(ctx context.Context, progressCb func(progress status.Progress)) *Hook {
50-
return &Hook{
51-
ctx: ctx,
52-
progress: make(map[digest.Digest]*status.ProgressItem),
53-
progressCb: progressCb,
54-
}
55-
}
56-
5728
type Puller interface {
5829
Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool) error
5930
}
6031

61-
var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *Hook, diskQuotaChecker *DiskQuotaChecker) Puller {
32+
var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker) Puller {
6233
return &puller{
6334
pullCfg: pullCfg,
6435
hook: hook,
@@ -68,144 +39,10 @@ var NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *Hook
6839

6940
type puller struct {
7041
pullCfg *config.PullConfig
71-
hook *Hook
42+
hook *status.Hook
7243
diskQuotaChecker *DiskQuotaChecker
7344
}
7445

75-
func (h *Hook) getProgressDesc() string {
76-
finished := h.pulled.Load()
77-
if h.manifest == nil {
78-
return fmt.Sprintf("%d/unknown", finished)
79-
}
80-
81-
total := len(h.manifest.Layers)
82-
83-
return fmt.Sprintf("%d/%d", finished, total)
84-
}
85-
86-
func (h *Hook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) {
87-
h.mutex.Lock()
88-
defer h.mutex.Unlock()
89-
90-
filePath := ""
91-
if desc.Annotations != nil && desc.Annotations[modelspec.AnnotationFilepath] != "" {
92-
filePath = fmt.Sprintf("/%s", desc.Annotations[modelspec.AnnotationFilepath])
93-
}
94-
95-
_, span := tracing.Tracer.Start(h.ctx, "PullLayer")
96-
span.SetAttributes(attribute.String("digest", desc.Digest.String()))
97-
span.SetAttributes(attribute.String("media_type", desc.MediaType))
98-
span.SetAttributes(attribute.String("file_path", filePath))
99-
span.SetAttributes(attribute.Int64("size", desc.Size))
100-
101-
h.manifest = &manifest
102-
h.progress[desc.Digest] = &status.ProgressItem{
103-
Digest: desc.Digest,
104-
Path: filePath,
105-
Size: desc.Size,
106-
StartedAt: time.Now(),
107-
FinishedAt: nil,
108-
Error: nil,
109-
Span: span,
110-
}
111-
112-
h.progressCb(h.getProgress())
113-
}
114-
115-
func (h *Hook) AfterPullLayer(desc ocispec.Descriptor, err error) {
116-
h.mutex.Lock()
117-
defer h.mutex.Unlock()
118-
119-
progress := h.progress[desc.Digest]
120-
if progress == nil {
121-
return
122-
}
123-
124-
metrics.NodePullOpObserve("pull_layer", progress.Size, progress.StartedAt, err)
125-
126-
var finishedAt *time.Time
127-
if err != nil {
128-
logger.WithContext(h.ctx).WithError(err).Errorf("failed to pull layer: %s%s (%s)", progress.Digest, progress.Path, h.getProgressDesc())
129-
} else {
130-
now := time.Now()
131-
finishedAt = &now
132-
h.pulled.Add(1)
133-
duration := time.Since(progress.StartedAt)
134-
logger.WithContext(h.ctx).Infof(
135-
"pulled layer: %s %s %s %s (%s) %s",
136-
desc.MediaType, progress.Digest, progress.Path, humanize.Bytes(uint64(progress.Size)), h.getProgressDesc(), duration,
137-
)
138-
}
139-
140-
progress.FinishedAt = finishedAt
141-
progress.Error = err
142-
143-
if err != nil {
144-
progress.Span.SetStatus(otelCodes.Error, "failed to pull layer")
145-
progress.Span.RecordError(err)
146-
}
147-
progress.Span.End()
148-
149-
h.progressCb(h.getProgress())
150-
}
151-
152-
func (p *puller) checkLongPulling(ctx context.Context) {
153-
ticker := time.NewTicker(30 * time.Second)
154-
defer ticker.Stop()
155-
156-
recorded := map[digest.Digest]bool{}
157-
158-
for {
159-
select {
160-
case <-ticker.C:
161-
p.hook.mutex.Lock()
162-
for _, progress := range p.hook.progress {
163-
if progress.FinishedAt == nil &&
164-
p.pullCfg.PullLayerTimeoutInSeconds > 0 &&
165-
time.Since(progress.StartedAt) > time.Duration(p.pullCfg.PullLayerTimeoutInSeconds)*time.Second &&
166-
!recorded[progress.Digest] {
167-
logger.WithContext(ctx).Warnf("pulling layer %s is taking too long: %s", progress.Digest, time.Since(progress.StartedAt))
168-
metrics.NodePullLayerTooLong.Inc()
169-
recorded[progress.Digest] = true
170-
}
171-
}
172-
p.hook.mutex.Unlock()
173-
case <-ctx.Done():
174-
return
175-
}
176-
}
177-
}
178-
179-
func (h *Hook) getProgress() status.Progress {
180-
items := []status.ProgressItem{}
181-
for _, item := range h.progress {
182-
items = append(items, *item)
183-
}
184-
185-
sort.Slice(items, func(i, j int) bool {
186-
if items[i].StartedAt.Equal(items[j].StartedAt) {
187-
return items[i].Digest < items[j].Digest
188-
}
189-
return items[i].StartedAt.Before(items[j].StartedAt)
190-
})
191-
192-
total := 0
193-
if h.manifest != nil {
194-
total = len(h.manifest.Layers)
195-
}
196-
return status.Progress{
197-
Total: total,
198-
Items: items,
199-
}
200-
}
201-
202-
func (h *Hook) GetProgress() status.Progress {
203-
h.mutex.Lock()
204-
defer h.mutex.Unlock()
205-
206-
return h.getProgress()
207-
}
208-
20946
func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeModelWeights bool) error {
21047
keyChain, err := auth.GetKeyChainByRef(reference)
21148
if err != nil {
@@ -231,8 +68,6 @@ func (p *puller) Pull(ctx context.Context, reference, targetDir string, excludeM
23168
}
23269

23370
if !excludeModelWeights {
234-
go p.checkLongPulling(ctx)
235-
23671
pullConfig := modctlConfig.NewPull()
23772
pullConfig.Concurrency = int(p.pullCfg.Concurrency)
23873
pullConfig.PlainHTTP = plainHTTP

pkg/service/worker.go

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (cm *ContextMap) Get(key string) *context.CancelFunc {
5252

5353
type Worker struct {
5454
cfg *config.Config
55-
newPuller func(ctx context.Context, pullCfg *config.PullConfig, hook *Hook, diskQuotaChecker *DiskQuotaChecker) Puller
55+
newPuller func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *DiskQuotaChecker) Puller
5656
sm *status.StatusManager
5757
inflight singleflight.Group
5858
contextMap *ContextMap
@@ -81,7 +81,6 @@ func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volu
8181
if err := worker.kmutex.Lock(context.Background(), contextKey); err != nil {
8282
return nil, errors.Wrapf(err, "lock context key: %s", contextKey)
8383
}
84-
8584
defer worker.kmutex.Unlock(contextKey)
8685

8786
volumeDir := worker.cfg.Get().GetVolumeDir(volumeName)
@@ -100,8 +99,13 @@ func (worker *Worker) deleteModel(ctx context.Context, isStaticVolume bool, volu
10099
return nil, errors.Wrapf(err, "retry remove volume dir: %s", volumeDir)
101100
}
102101
logger.WithContext(ctx).Infof("removed volume dir: %s", volumeDir)
102+
103+
statusPath := filepath.Join(volumeDir, "status.json")
104+
worker.sm.HookManager.Delete(statusPath)
105+
103106
return nil, nil
104107
})
108+
105109
return err
106110
}
107111

@@ -139,13 +143,12 @@ func (worker *Worker) PullModel(
139143
}
140144

141145
func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mountID, reference, modelDir string, checkDiskQuota, excludeModelWeights bool) error {
142-
setStatus := func(state status.State, progress status.Progress) (*status.Status, error) {
146+
setStatus := func(state status.State) (*status.Status, error) {
143147
status, err := worker.sm.Set(statusPath, status.Status{
144148
VolumeName: volumeName,
145149
MountID: mountID,
146150
Reference: reference,
147151
State: state,
148-
Progress: progress,
149152
})
150153
if err != nil {
151154
return nil, errors.Wrapf(err, "set model status")
@@ -181,41 +184,39 @@ func (worker *Worker) pullModel(ctx context.Context, statusPath, volumeName, mou
181184
return nil, errors.Wrapf(err, "cleanup model directory before pull: %s", modelDir)
182185
}
183186

184-
hook := NewHook(ctx, func(progress status.Progress) {
185-
if _, err := setStatus(status.StatePullRunning, progress); err != nil {
186-
logger.WithContext(ctx).WithError(err).Errorf("set model status: %v", err)
187-
}
188-
})
187+
hook := status.NewHook(ctx)
188+
worker.sm.HookManager.Set(statusPath, hook)
189+
189190
var diskQuotaChecker *DiskQuotaChecker
190191
checkDiskQuota := worker.cfg.Get().Features.CheckDiskQuota && checkDiskQuota && !worker.isModelExisted(ctx, reference)
191192
if checkDiskQuota {
192193
diskQuotaChecker = NewDiskQuotaChecker(worker.cfg)
193194
}
194195
puller := worker.newPuller(ctx, &worker.cfg.Get().PullConfig, hook, diskQuotaChecker)
195-
_, err := setStatus(status.StatePullRunning, hook.GetProgress())
196+
_, err := setStatus(status.StatePullRunning)
196197
if err != nil {
197198
return nil, errors.Wrapf(err, "set status before pull model")
198199
}
199200
if err := puller.Pull(ctx, reference, modelDir, excludeModelWeights); err != nil {
200201
if errors.Is(err, context.Canceled) {
201202
err = errors.Wrapf(err, "pull model canceled")
202-
if _, err2 := setStatus(status.StatePullCanceled, hook.GetProgress()); err2 != nil {
203+
if _, err2 := setStatus(status.StatePullCanceled); err2 != nil {
203204
return nil, errors.Wrapf(err, "set model status: %v", err2)
204205
}
205206
} else if errors.Is(err, context.DeadlineExceeded) {
206207
err = errors.Wrapf(err, "pull model timeout")
207-
if _, err2 := setStatus(status.StatePullTimeout, hook.GetProgress()); err2 != nil {
208+
if _, err2 := setStatus(status.StatePullTimeout); err2 != nil {
208209
return nil, errors.Wrapf(err, "set model status: %v", err2)
209210
}
210211
} else {
211212
err = errors.Wrapf(err, "pull model failed")
212-
if _, err2 := setStatus(status.StatePullFailed, hook.GetProgress()); err2 != nil {
213+
if _, err2 := setStatus(status.StatePullFailed); err2 != nil {
213214
return nil, errors.Wrapf(err, "set model status: %v", err2)
214215
}
215216
}
216217
return nil, err
217218
}
218-
_, err = setStatus(status.StatePullSucceeded, hook.GetProgress())
219+
_, err = setStatus(status.StatePullSucceeded)
219220
if err != nil {
220221
return nil, errors.Wrapf(err, "set status after pull model succeeded")
221222
}

0 commit comments

Comments
 (0)