Skip to content

Commit 48883fd

Browse files
authored
fix: Propagate context to puller download tasks (#32)
#### Motivation For some reason, the `LoadModel` context is not passed via the puller's `ProcessLoadModelRequest` function to pullman's `Pull` function. This means that storage downloads won't be cancelled when the corresponding request is cancelled or times out. #### Modifications Add a `Context` argument to the puller's `ProcessLoadModelRequest` function, use in the call to pull manager's `Pull` instead of `context.TODO()`. #### Result In-progress model downloads from shared storage downloads will be cancelled when the corresponding request is cancelled or times out. Signed-off-by: Nick Hill <[email protected]>
1 parent 93ee2b2 commit 48883fd

File tree

6 files changed

+20
-19
lines changed

6 files changed

+20
-19
lines changed

model-mesh-mlserver-adapter/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (s *MLServerAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadMo
121121

122122
if s.AdapterConfig.UseEmbeddedPuller {
123123
var pullerErr error
124-
req, pullerErr = s.Puller.ProcessLoadModelRequest(req)
124+
req, pullerErr = s.Puller.ProcessLoadModelRequest(ctx, req)
125125
if pullerErr != nil {
126126
log.Error(pullerErr, "Failed to pull model from storage")
127127
return nil, pullerErr

model-mesh-ovms-adapter/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (s *OvmsAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadModelR
117117

118118
if s.AdapterConfig.UseEmbeddedPuller {
119119
var pullerErr error
120-
req, pullerErr = s.Puller.ProcessLoadModelRequest(req)
120+
req, pullerErr = s.Puller.ProcessLoadModelRequest(ctx, req)
121121
if pullerErr != nil {
122122
log.Error(pullerErr, "Failed to pull model from storage")
123123
return nil, pullerErr

model-mesh-triton-adapter/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (s *TritonAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadMode
106106

107107
if s.AdapterConfig.UseEmbeddedPuller {
108108
var pullerErr error
109-
req, pullerErr = s.Puller.ProcessLoadModelRequest(req)
109+
req, pullerErr = s.Puller.ProcessLoadModelRequest(ctx, req)
110110
if pullerErr != nil {
111111
log.Error(pullerErr, "Failed to pull model from storage")
112112
return nil, pullerErr

model-serving-puller/puller/puller.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func NewPullerFromConfig(log logr.Logger, config *PullerConfiguration) *Puller {
9191
// - rewrite ModelPath to a local filesystem path
9292
// - rewrite ModelKey["schema_path"] to a local filesystem path
9393
// - add the size of the model on disk to ModelKey["disk_size_bytes"]
94-
func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.LoadModelRequest, error) {
94+
func (s *Puller) ProcessLoadModelRequest(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelRequest, error) {
9595
var modelKey ModelKeyInfo
9696
if parseErr := json.Unmarshal([]byte(req.ModelKey), &modelKey); parseErr != nil {
9797
return nil, fmt.Errorf("Invalid modelKey in LoadModelRequest. Error processing JSON '%s': %w", req.ModelKey, parseErr)
@@ -177,7 +177,7 @@ func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.Lo
177177
Directory: modelDir,
178178
Targets: targets,
179179
}
180-
pullerErr := s.PullManager.Pull(context.TODO(), pullCommand)
180+
pullerErr := s.PullManager.Pull(ctx, pullCommand)
181181
if pullerErr != nil {
182182
return nil, status.Errorf(status.Code(pullerErr), "Failed to pull model from storage due to error: %s", pullerErr)
183183
}

model-serving-puller/puller/puller_test.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package puller
1515

1616
import (
17+
"context"
1718
"encoding/json"
1819
"fmt"
1920
"io/ioutil"
@@ -176,7 +177,7 @@ func Test_ProcessLoadModelRequest_Success_SingleFileModel(t *testing.T) {
176177

177178
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
178179

179-
returnRequest, err := p.ProcessLoadModelRequest(request)
180+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
180181
assert.Nil(t, err)
181182
assert.Equal(t, expectedRequestRewrite, returnRequest)
182183
}
@@ -216,7 +217,7 @@ func Test_ProcessLoadModelRequest_Success_MultiFileModel(t *testing.T) {
216217

217218
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
218219

219-
returnRequest, err := p.ProcessLoadModelRequest(request)
220+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
220221
assert.Nil(t, err)
221222
assert.Equal(t, expectedRequestRewrite, returnRequest)
222223
}
@@ -262,7 +263,7 @@ func Test_ProcessLoadModelRequest_SuccessWithSchema(t *testing.T) {
262263

263264
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
264265

265-
returnRequest, err := p.ProcessLoadModelRequest(request)
266+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
266267
assert.Nil(t, err)
267268
assert.Equal(t, expectedRequestRewrite, returnRequest)
268269
}
@@ -302,7 +303,7 @@ func Test_ProcessLoadModelRequest_SuccessWithBucket(t *testing.T) {
302303

303304
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
304305

305-
returnRequest, err := p.ProcessLoadModelRequest(request)
306+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
306307
assert.Nil(t, err)
307308
assert.Equal(t, expectedRequestRewrite, returnRequest)
308309
}
@@ -342,7 +343,7 @@ func Test_ProcessLoadModelRequest_SuccessNoBucket(t *testing.T) {
342343

343344
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
344345

345-
returnRequest, err := p.ProcessLoadModelRequest(request)
346+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
346347
assert.Nil(t, err)
347348
assert.Equal(t, expectedRequestRewrite, returnRequest)
348349
}
@@ -382,7 +383,7 @@ func Test_ProcessLoadModelRequest_SuccessNoBucketNoStorageParams(t *testing.T) {
382383

383384
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
384385

385-
returnRequest, err := p.ProcessLoadModelRequest(request)
386+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
386387
assert.Nil(t, err)
387388
assert.Equal(t, expectedRequestRewrite, returnRequest)
388389
}
@@ -419,7 +420,7 @@ func Test_ProcessLoadModelRequest_SuccessStorageTypeOnly(t *testing.T) {
419420

420421
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
421422

422-
returnRequest, err := p.ProcessLoadModelRequest(request)
423+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
423424
assert.Nil(t, err)
424425
assert.Equal(t, expectedRequestRewrite, returnRequest)
425426
}
@@ -466,7 +467,7 @@ func Test_ProcessLoadModelRequest_DefaultStorageKey(t *testing.T) {
466467

467468
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
468469

469-
returnRequest, err := p.ProcessLoadModelRequest(request)
470+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
470471
assert.Nil(t, err)
471472
assert.Equal(t, expectedRequestRewrite, returnRequest)
472473
}
@@ -504,7 +505,7 @@ func Test_ProcessLoadModelRequest_DefaultStorageKeyTyped(t *testing.T) {
504505

505506
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
506507

507-
returnRequest, err := p.ProcessLoadModelRequest(request)
508+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
508509
assert.Nil(t, err)
509510
assert.Equal(t, expectedRequestRewrite, returnRequest)
510511
}
@@ -544,7 +545,7 @@ func Test_ProcessLoadModelRequest_StorageParamsOverrides(t *testing.T) {
544545

545546
mockPuller.EXPECT().Pull(gomock.Any(), eqPullCommand(&expectedPullCommand)).Return(nil).Times(1)
546547

547-
returnRequest, err := p.ProcessLoadModelRequest(request)
548+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
548549
assert.Nil(t, err)
549550
assert.Equal(t, expectedRequestRewrite, returnRequest)
550551
}
@@ -559,7 +560,7 @@ func Test_ProcessLoadModelRequest_FailInvalidModelKey(t *testing.T) {
559560

560561
p, _ := newPullerWithMock(t)
561562

562-
returnRequest, err := p.ProcessLoadModelRequest(request)
563+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
563564
assert.Contains(t, err.Error(), "Invalid modelKey in LoadModelRequest")
564565
assert.Nil(t, returnRequest)
565566
}
@@ -574,7 +575,7 @@ func Test_ProcessLoadModelRequest_FailInvalidSchemaPath(t *testing.T) {
574575

575576
p, _ := newPullerWithMock(t)
576577

577-
returnRequest, err := p.ProcessLoadModelRequest(request)
578+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
578579
assert.Nil(t, returnRequest)
579580
assert.Error(t, err)
580581
assert.Contains(t, err.Error(), "Invalid modelKey in LoadModelRequest")
@@ -591,7 +592,7 @@ func Test_ProcessLoadModelRequest_FailMissingStorageKeyAndType(t *testing.T) {
591592

592593
p, _ := newPullerWithMock(t)
593594

594-
returnRequest, err := p.ProcessLoadModelRequest(request)
595+
returnRequest, err := p.ProcessLoadModelRequest(context.Background(), request)
595596
assert.Nil(t, returnRequest)
596597
assert.EqualError(t, err, expectedError)
597598
}

model-serving-puller/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (s *PullerServer) loadModel(ctx context.Context, req *mmesh.LoadModelReques
129129

130130
// Pull the model from storage
131131
var pullerErr error
132-
req, pullerErr = s.puller.ProcessLoadModelRequest(req)
132+
req, pullerErr = s.puller.ProcessLoadModelRequest(ctx, req)
133133
if pullerErr != nil {
134134
log.Error(pullerErr, "Failed to pull model from storage")
135135
return nil, pullerErr

0 commit comments

Comments
 (0)