Skip to content

Commit 4aee00f

Browse files
Merge pull request #10 from Jooho/sync_kserve_230501
[Source Sync] 2023.05.01 version
2 parents 9a047f5 + 7fc855f commit 4aee00f

File tree

5 files changed

+107
-9
lines changed

5 files changed

+107
-9
lines changed

OWNERS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@ reviewers:
1111
- Jooho
1212
- VedantMahabaleshwarkar
1313
- Xaenalt
14-

model-mesh-mlserver-adapter/server/adaptmodellayout_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,36 @@ var adaptModelLayoutTests = []adaptModelLayoutTestCase{
405405
},
406406
},
407407
},
408+
409+
// model with filename that alphabetically precedes model-settings.json
410+
411+
{
412+
ModelID: "model-filename-precedes-model-settings",
413+
ModelType: "sklearn",
414+
ModelPath: "data",
415+
InputFiles: []string{
416+
"data/model-settings.json",
417+
"data/aaaaa.json",
418+
},
419+
InputConfig: map[string]interface{}{
420+
"name": "model-name",
421+
"implementation": "mlserver_sklearn.SKLearnModel",
422+
"parameters": map[string]interface{}{
423+
"uri": "./aaaaa.json",
424+
},
425+
},
426+
ExpectedFiles: []string{
427+
"model-settings.json",
428+
"aaaaa.json",
429+
},
430+
ExpectedConfig: map[string]interface{}{
431+
"name": "model-filename-precedes-model-settings",
432+
"implementation": "mlserver_sklearn.SKLearnModel",
433+
"parameters": map[string]interface{}{
434+
"uri": filepath.Join(generatedMlserverModelsDir, "model-filename-precedes-model-settings", "aaaaa.json"),
435+
},
436+
},
437+
},
408438
}
409439

410440
func TestAdaptModelLayoutForRuntime(t *testing.T) {

model-mesh-mlserver-adapter/server/server.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,18 @@ func adaptModelLayoutForRuntime(rootModelDir, modelID, modelType, modelPath, sch
190190
// check if the config file exists
191191
// if it does, we assume files are in the "native" repo structure
192192
assumeNativeLayout := false
193-
for _, f := range files {
193+
configFileIndex := -1
194+
for i, f := range files {
194195
if f.Name() == mlserverRepositoryConfigFilename {
195196
assumeNativeLayout = true
197+
configFileIndex = i
196198
break
197199
}
198200
}
201+
// always process the config file first to ensure that uri conversion within config file is correct
202+
if configFileIndex > 0 {
203+
files[0], files[configFileIndex] = files[configFileIndex], files[0]
204+
}
199205
if assumeNativeLayout {
200206
err = adaptNativeModelLayout(files, modelID, modelPath, schemaPath, mlserverModelIDDir, log)
201207
} else {

model-serving-puller/server/modelstate.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,19 @@ func (m *modelStateManager) submitRequest(ctx context.Context, req grpcRequest)
103103
}
104104

105105
func (m *modelStateManager) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) {
106-
resp, err := m.submitRequest(ctx, req)
107-
return resp.(*mmesh.LoadModelResponse), err
106+
res, err := m.submitRequest(ctx, req)
107+
if resp, ok := res.(*mmesh.LoadModelResponse); ok {
108+
return resp, err
109+
}
110+
return nil, err
108111
}
109112

110113
func (m *modelStateManager) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) {
111-
resp, err := m.submitRequest(ctx, req)
112-
return resp.(*mmesh.UnloadModelResponse), err
114+
res, err := m.submitRequest(ctx, req)
115+
if resp, ok := res.(*mmesh.UnloadModelResponse); ok {
116+
return resp, err
117+
}
118+
return nil, err
113119
}
114120

115121
func (m *modelStateManager) execute() {

model-serving-puller/server/modelstate_test.go

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ package server
1515

1616
import (
1717
"context"
18+
"errors"
1819
"testing"
20+
"time"
1921

2022
"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
2123
"sigs.k8s.io/controller-runtime/pkg/log/zap"
@@ -43,16 +45,71 @@ func TestStateManagerLoadModel(t *testing.T) {
4345
mockPullerServer := &mockPullerServer{}
4446
sm.s = mockPullerServer
4547

46-
req := &mmesh.LoadModelRequest{ModelId: "model-id"}
47-
sm.loadModel(context.Background(), req)
48+
load_req := &mmesh.LoadModelRequest{ModelId: "model-id"}
49+
sm.loadModel(context.Background(), load_req)
4850

4951
if mockPullerServer.loaded != 1 {
5052
t.Fatal("Load should have been called 1 time")
5153
}
5254
if mockPullerServer.unloaded != 0 {
53-
t.Fatal("Load should have been called 1 time")
55+
t.Fatal("Unload should not have been called")
56+
}
57+
if len(sm.data) > 0 {
58+
t.Fatal("StateManager map should be empty")
59+
}
60+
61+
// now unload the model
62+
unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"}
63+
sm.unloadModel(context.Background(), unload_req)
64+
65+
if mockPullerServer.unloaded != 1 {
66+
t.Fatal("Unload should now have been called")
5467
}
5568
if len(sm.data) > 0 {
5669
t.Fatal("StateManager map should be empty")
5770
}
5871
}
72+
73+
type mockPullerServerError struct {
74+
}
75+
76+
func (m *mockPullerServerError) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) {
77+
// sleep to simulate a delay that could cause the context to be cancelled
78+
time.Sleep(50 * time.Millisecond)
79+
return nil, errors.New("failed load")
80+
}
81+
82+
func (m *mockPullerServerError) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) {
83+
return nil, errors.New("failed unload")
84+
}
85+
86+
func TestStateManagerErrors(t *testing.T) {
87+
log := zap.New()
88+
s := NewPullerServer(log)
89+
sm := s.sm
90+
mockPullerServerError := &mockPullerServerError{}
91+
sm.s = mockPullerServerError
92+
93+
// check that error returns are handled
94+
var err error
95+
load_req := &mmesh.LoadModelRequest{ModelId: "model-id"}
96+
_, err = sm.loadModel(context.Background(), load_req)
97+
if err == nil {
98+
t.Fatal("An error should have been returned")
99+
}
100+
101+
unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"}
102+
_, err = sm.unloadModel(context.Background(), unload_req)
103+
if err == nil {
104+
t.Fatal("An error should have been returned")
105+
}
106+
107+
// check that context cancellation is handled
108+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
109+
defer cancel()
110+
_, err = sm.loadModel(ctx, load_req)
111+
if err == nil {
112+
t.Fatal("An error should have been returned")
113+
}
114+
115+
}

0 commit comments

Comments
 (0)