Skip to content

Commit 3632b16

Browse files
fix: don't panic for some errors in modelstate (#42)
#### Motivation If a model load timeout is hit a panic is generated in the modelStateManager: ``` panic: interface conversion: interface {} is nil, not *mmesh.LoadModelResponse goroutine 1397 [running]: github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/server.(*modelStateManager).loadModel(...) /opt/app-root/src/model-serving-puller/server/modelstate.go:107 github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/server.(*PullerServer).LoadModel(0xc00016c550, {0x12df4f8, 0xc0005fc960}, 0x485800) /opt/app-root/src/model-serving-puller/server/server.go:116 +0xad github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh._ModelRuntime_LoadModel_Handler({0x10219e0, 0xc00016c550}, {0x12df4f8, 0xc0005fc960}, 0xc0004602a0, 0x0) /opt/app-root/src/internal/proto/mmesh/model-runtime_grpc.pb.go:181 +0x170 google.golang.org/grpc.(*Server).processUnaryRPC(0xc00054a1c0, {0x12f37d0, 0xc0006ca4e0}, 0xc000b326c0, 0xc00027daa0, 0x1a6f780, 0x0) /remote-source/deps/gomod/pkg/mod/google.golang.org/[email protected]/server.go:1301 +0xb03 google.golang.org/grpc.(*Server).handleStream(0xc00054a1c0, {0x12f37d0, 0xc0006ca4e0}, 0xc000b326c0, 0x0) /remote-source/deps/gomod/pkg/mod/google.golang.org/[email protected]/server.go:1642 +0xa2a google.golang.org/grpc.(*Server).serveStreams.func1.2() /remote-source/deps/gomod/pkg/mod/google.golang.org/[email protected]/server.go:938 +0x98 created by google.golang.org/grpc.(*Server).serveStreams.func1 /remote-source/deps/gomod/pkg/mod/google.golang.org/[email protected]/server.go:936 +0x294 ``` In a couple of error cases in `submitRequest`, `nil` is returned as the first return value with the error. The code in `loadModel` and `unloadModel` always attempts to cast the value to a pointer to a response, but this will panic if attempting to convert `nil`. #### Modifications - add a test to reproduce the panic - change the code to use a comma-ok type assertion instead of panicking #### Result The puller/adapter doesn't crash when a model load times out. Signed-off-by: Travis Johnson <[email protected]>
1 parent 6992097 commit 3632b16

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

model-serving-puller/server/modelstate.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,19 @@ func (m *modelStateManager) submitRequest(ctx context.Context, req grpcRequest)
103103
}
104104

105105
func (m *modelStateManager) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) {
106-
resp, err := m.submitRequest(ctx, req)
107-
return resp.(*mmesh.LoadModelResponse), err
106+
res, err := m.submitRequest(ctx, req)
107+
if resp, ok := res.(*mmesh.LoadModelResponse); ok {
108+
return resp, err
109+
}
110+
return nil, err
108111
}
109112

110113
func (m *modelStateManager) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) {
111-
resp, err := m.submitRequest(ctx, req)
112-
return resp.(*mmesh.UnloadModelResponse), err
114+
res, err := m.submitRequest(ctx, req)
115+
if resp, ok := res.(*mmesh.UnloadModelResponse); ok {
116+
return resp, err
117+
}
118+
return nil, err
113119
}
114120

115121
func (m *modelStateManager) execute() {

model-serving-puller/server/modelstate_test.go

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ package server
1515

1616
import (
1717
"context"
18+
"errors"
1819
"testing"
20+
"time"
1921

2022
"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
2123
"sigs.k8s.io/controller-runtime/pkg/log/zap"
@@ -43,16 +45,71 @@ func TestStateManagerLoadModel(t *testing.T) {
4345
mockPullerServer := &mockPullerServer{}
4446
sm.s = mockPullerServer
4547

46-
req := &mmesh.LoadModelRequest{ModelId: "model-id"}
47-
sm.loadModel(context.Background(), req)
48+
load_req := &mmesh.LoadModelRequest{ModelId: "model-id"}
49+
sm.loadModel(context.Background(), load_req)
4850

4951
if mockPullerServer.loaded != 1 {
5052
t.Fatal("Load should have been called 1 time")
5153
}
5254
if mockPullerServer.unloaded != 0 {
53-
t.Fatal("Load should have been called 1 time")
55+
t.Fatal("Unload should not have been called")
56+
}
57+
if len(sm.data) > 0 {
58+
t.Fatal("StateManager map should be empty")
59+
}
60+
61+
// now unload the model
62+
unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"}
63+
sm.unloadModel(context.Background(), unload_req)
64+
65+
if mockPullerServer.unloaded != 1 {
66+
t.Fatal("Unload should now have been called")
5467
}
5568
if len(sm.data) > 0 {
5669
t.Fatal("StateManager map should be empty")
5770
}
5871
}
72+
73+
type mockPullerServerError struct {
74+
}
75+
76+
func (m *mockPullerServerError) loadModel(ctx context.Context, req *mmesh.LoadModelRequest) (*mmesh.LoadModelResponse, error) {
77+
// sleep to simulate a delay that could cause the context to be cancelled
78+
time.Sleep(50 * time.Millisecond)
79+
return nil, errors.New("failed load")
80+
}
81+
82+
func (m *mockPullerServerError) unloadModel(ctx context.Context, req *mmesh.UnloadModelRequest) (*mmesh.UnloadModelResponse, error) {
83+
return nil, errors.New("failed unload")
84+
}
85+
86+
func TestStateManagerErrors(t *testing.T) {
87+
log := zap.New()
88+
s := NewPullerServer(log)
89+
sm := s.sm
90+
mockPullerServerError := &mockPullerServerError{}
91+
sm.s = mockPullerServerError
92+
93+
// check that error returns are handled
94+
var err error
95+
load_req := &mmesh.LoadModelRequest{ModelId: "model-id"}
96+
_, err = sm.loadModel(context.Background(), load_req)
97+
if err == nil {
98+
t.Fatal("An error should have been returned")
99+
}
100+
101+
unload_req := &mmesh.UnloadModelRequest{ModelId: "model-id"}
102+
_, err = sm.unloadModel(context.Background(), unload_req)
103+
if err == nil {
104+
t.Fatal("An error should have been returned")
105+
}
106+
107+
// check that context cancellation is handled
108+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
109+
defer cancel()
110+
_, err = sm.loadModel(ctx, load_req)
111+
if err == nil {
112+
t.Fatal("An error should have been returned")
113+
}
114+
115+
}

0 commit comments

Comments
 (0)