Commit 9a61ddc

feat: TorchServe support (#34)
#### Motivation

The Triton runtime can be used with model-mesh to serve PyTorch TorchScript models, but it does not support arbitrary PyTorch models, i.e. eager mode. KServe "classic" has an integration with TorchServe, and it would be good to have one for model-mesh too so that these kinds of models can be used in distributed multi-model serving contexts.

#### Modifications

- Add adapter logic that implements the model-mesh management SPI using the TorchServe gRPC management API
- Build and include the new adapter binary in the Docker image
- Add a mock server and basic unit tests

Implementation notes:

- Model size (memory usage) is not returned from the `LoadModel` RPC but is instead determined separately in the `ModelSize` RPC, so that the model is available for use slightly sooner
- TorchServe's `DescribeModel` RPC is used to determine the model's memory usage. If that isn't successful, the adapter falls back to a multiple of the model's size on disk (similar to other runtimes)
- The adapter writes the config file for TorchServe to consume

TorchServe does not yet support the KServe V2 gRPC prediction API (only REST), which means that API can't currently be used with model-mesh. The native TorchServe gRPC inference interface can be used instead for the time being.

A smaller PR will be opened against the main modelmesh-serving controller repo to enable use of TorchServe, including the ServingRuntime specification.

#### Result

TorchServe can be used seamlessly with ModelMesh Serving to serve PyTorch models, including eager-mode models.

Resolves kserve#4
Contributes to kserve/modelmesh-serving#63

Signed-off-by: Nick Hill <[email protected]>
1 parent f4c43a3 commit 9a61ddc
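
The size-determination fallback described in the implementation notes could look roughly like the sketch below. It is illustrative only: the function names, the disk-size multiplier value, and the way the `DescribeModel` result is obtained are assumptions, not the adapter's actual code.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
)

// Assumed multiplier applied to the on-disk size when memory usage can't be
// queried from TorchServe; other runtime adapters use a similar heuristic.
const diskSizeMultiplier = 3

// modelMemSize returns the model's memory usage in bytes, preferring the value
// reported by TorchServe (describeFn stands in for a DescribeModel call) and
// falling back to a multiple of the model's size on disk.
func modelMemSize(ctx context.Context, modelPath string, describeFn func(context.Context) (uint64, error)) (uint64, error) {
	if mem, err := describeFn(ctx); err == nil && mem > 0 {
		return mem, nil
	}
	diskSize, err := dirSize(modelPath)
	if err != nil {
		return 0, fmt.Errorf("could not determine model size: %w", err)
	}
	return diskSize * diskSizeMultiplier, nil
}

// dirSize sums the sizes of all regular files under path.
func dirSize(path string) (uint64, error) {
	var total uint64
	err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.Mode().IsRegular() {
			total += uint64(info.Size())
		}
		return nil
	})
	return total, err
}

func main() {
	// Stub DescribeModel call that fails, forcing the disk-size fallback.
	describe := func(context.Context) (uint64, error) { return 0, errors.New("unavailable") }
	size, err := modelMemSize(context.Background(), ".", describe)
	fmt.Println(size, err)
}
```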

File tree

18 files changed: +3366 −1 lines changed


Dockerfile

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,7 @@ RUN go build -o puller model-serving-puller/main.go
 RUN go build -o triton-adapter model-mesh-triton-adapter/main.go
 RUN go build -o mlserver-adapter model-mesh-mlserver-adapter/main.go
 RUN go build -o ovms-adapter model-mesh-ovms-adapter/main.go
-
+RUN go build -o torchserve-adapter model-mesh-torchserve-adapter/main.go
 
 ###############################################################################
 # Stage 3: Copy build assets to create the smallest final runtime image
@@ -121,6 +121,8 @@ COPY --from=build /opt/app/triton-adapter /opt/app/
 COPY --from=build /opt/app/mlserver-adapter /opt/app/
 COPY --from=build /opt/app/model-mesh-triton-adapter/scripts/tf_pb.py /opt/scripts/
 COPY --from=build /opt/app/ovms-adapter /opt/app/
+COPY --from=build /opt/app/torchserve-adapter /opt/app/
+
 
 # Don't define an entrypoint. This is a multi-purpose image so the user should specify which binary they want to run (e.g. /opt/app/puller or /opt/app/triton-adapter)
 # ENTRYPOINT ["/opt/app/puller"]

internal/proto/torchserve/inference.pb.go

Lines changed: 330 additions & 0 deletions
Some generated files are not rendered by default.

internal/proto/torchserve/inference.proto

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
// Copied from https://github.com/pytorch/serve/blob/8c23585d2453f230c411721028ad4b07e58cc7dd/frontend/server/src/main/resources/proto/inference.proto

syntax = "proto3";

package org.pytorch.serve.grpc.inference;

import "google/protobuf/empty.proto";

option java_multiple_files = true;

message PredictionsRequest {
    // Name of model.
    string model_name = 1; //required

    // Version of model to run prediction on.
    string model_version = 2; //optional

    // input data for model prediction
    map<string, bytes> input = 3; //required
}

message PredictionResponse {
    // TorchServe health
    bytes prediction = 1;
}

message TorchServeHealthResponse {
    // TorchServe health
    string health = 1;
}

service InferenceAPIsService {
    rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {}

    // Predictions entry point to get inference using default model version.
    rpc Predictions(PredictionsRequest) returns (PredictionResponse) {}
}
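
Because TorchServe doesn't yet expose the KServe V2 gRPC prediction API, inference requests go through this native interface. Below is a rough client sketch against Go stubs generated from the proto above; the import path, address/port, model name, and input key are assumptions for illustration, not part of this commit.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	// Assumed import path for the generated stubs included in this commit.
	torchserve "github.com/kserve/modelmesh-runtime-adapter/internal/proto/torchserve"
)

func main() {
	// 7070 is TorchServe's default gRPC inference port; adjust as needed.
	conn, err := grpc.Dial("localhost:7070", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := torchserve.NewInferenceAPIsServiceClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// The "data" key and JSON payload are placeholders; the expected input
	// depends on the model's handler.
	resp, err := client.Predictions(ctx, &torchserve.PredictionsRequest{
		ModelName: "my-model",
		Input:     map[string][]byte{"data": []byte(`{"instances": [[1.0, 2.0]]}`)},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("prediction: %s\n", resp.GetPrediction())
}
```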
