Skip to content

Commit e94d2a9

Browse files
Raise TritonModelException if the Triton model has an error (#333)
* Raise TritonModelException if the PyTorch Triton model has an error
* Raise from exc in executor_model
* Raise RuntimeError after every InferenceRequest if there is an error
* Add test of PredictPyTorch with Triton
* Apply pre-commit formatting after merge

---------

Co-authored-by: Karl Higley <karlb@nvidia.com>
Co-authored-by: Karl Higley <kmhigley@gmail.com>
1 parent 20d1242 commit e94d2a9

File tree

6 files changed

+114
-7
lines changed

6 files changed

+114
-7
lines changed

merlin/systems/dag/runtimes/triton/ops/fil.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ def transform(
152152
self.fil_model_name, inputs, input_schema, output_schema
153153
)
154154
inference_response = inference_request.exec()
155+
156+
if inference_response.has_error():
157+
raise RuntimeError(str(inference_response.error().message()))
158+
155159
return triton_response_to_tensor_table(inference_response, type(inputs), output_schema)
156160

157161

merlin/systems/dag/runtimes/triton/ops/pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
8383

8484
inference_response = inference_request.exec()
8585

86+
if inference_response.has_error():
87+
raise RuntimeError(str(inference_response.error().message()))
88+
8689
return triton_response_to_tensor_table(
8790
inference_response, type(transformable), self.output_schema
8891
)

merlin/systems/dag/runtimes/triton/ops/tensorflow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
7676
)
7777
inference_response = inference_request.exec()
7878

79+
if inference_response.has_error():
80+
raise RuntimeError(inference_response.error().message())
81+
7982
# TODO: Validate that the outputs match the schema
8083
return triton_response_to_tensor_table(
8184
inference_response, type(transformable), self.output_schema

merlin/systems/dag/runtimes/triton/ops/workflow.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from shutil import copyfile
2121

2222
import tritonclient.grpc.model_config_pb2 as model_config
23-
import tritonclient.utils
2423
from google.protobuf import text_format
2524

2625
from merlin.core.protocols import Transformable
@@ -89,12 +88,8 @@ def transform(self, col_selector: ColumnSelector, transformable: Transformable):
8988

9089
inference_response = inference_request.exec()
9190

92-
# check inference response for errors:
9391
if inference_response.has_error():
94-
# Cannot raise inference response error because it is not derived from BaseException
95-
raise tritonclient.utils.InferenceServerException(
96-
str(inference_response.error().message())
97-
)
92+
raise RuntimeError(inference_response.error().message())
9893

9994
response_table = triton_response_to_tensor_table(
10095
inference_response, type(transformable), self.output_schema

merlin/systems/triton/models/executor_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import pathlib
2727
from pathlib import Path
2828

29+
import triton_python_backend_utils as pb_utils
30+
2931
from merlin.dag import postorder_iter_nodes
3032
from merlin.systems.dag import Ensemble
3133
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
@@ -93,7 +95,10 @@ def execute(self, request):
9395
be the same as `requests`
9496
"""
9597
inputs = triton_request_to_tensor_table(request, self.ensemble.input_schema)
96-
outputs = self.ensemble.transform(inputs, runtime=TritonExecutorRuntime())
98+
try:
99+
outputs = self.ensemble.transform(inputs, runtime=TritonExecutorRuntime())
100+
except Exception as exc:
101+
raise pb_utils.TritonModelException(str(exc)) from exc
97102
return tensor_table_to_triton_response(outputs, self.ensemble.output_schema)
98103

99104

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import shutil
18+
19+
import numpy as np
20+
import pandas as pd
21+
import pytest
22+
import tritonclient.utils
23+
24+
from merlin.schema import ColumnSchema, Schema
25+
from merlin.systems.dag.ensemble import Ensemble
26+
from merlin.systems.dag.ops.pytorch import PredictPyTorch
27+
from merlin.systems.triton.utils import run_ensemble_on_tritonserver
28+
29+
torch = pytest.importorskip("torch")
30+
31+
TRITON_SERVER_PATH = shutil.which("tritonserver")
32+
33+
34+
@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
def test_model_in_ensemble(tmpdir):
    """Round-trip a traced PyTorch model through a Triton ensemble.

    Builds a tiny module that sums its input columns element-wise, wraps it
    in a ``PredictPyTorch`` node, exports the ensemble, and checks that
    Triton returns the expected sum for a single-row request.
    """

    class SumModule(torch.nn.Module):
        def forward(self, x):
            # Element-wise sum across all input columns.
            return torch.stack(list(x.values())).sum(axis=0)

    traced = torch.jit.trace(
        SumModule(), {"a": torch.tensor(1), "b": torch.tensor(2)}, strict=True
    )

    input_schema = Schema(
        [ColumnSchema("a", dtype="int64"), ColumnSchema("b", dtype="int64")]
    )
    output_schema = Schema([ColumnSchema("output", dtype="int64")])

    model_node = input_schema.column_names >> PredictPyTorch(
        traced, input_schema, output_schema
    )
    ensemble = Ensemble(model_node, input_schema)
    ensemble_config, _ = ensemble.export(str(tmpdir))

    request_df = pd.DataFrame({"a": [1], "b": [2]})
    response = run_ensemble_on_tritonserver(
        str(tmpdir), input_schema, request_df, ["output"], ensemble_config.name
    )

    # 1 + 2 == 3 for the single request row.
    np.testing.assert_array_equal(response["output"], np.array([3]))
64+
65+
66+
@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
def test_model_error(tmpdir):
    """Errors inside a PyTorch model surface as ``InferenceServerException``.

    The model is traced with inputs "a" and "b", but the ensemble only
    declares (and the request only supplies) column "a".  The TorchScript
    interpreter therefore raises a KeyError at inference time, which should
    propagate back to the client as an ``InferenceServerException`` whose
    message contains the original TorchScript error text.
    """

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.stack(list(x.values())).sum()

    traced = torch.jit.trace(
        SumModule(), {"a": torch.tensor(1), "b": torch.tensor(2)}, strict=True
    )

    # The ensemble schema intentionally omits "b" to provoke a runtime error.
    input_schema = Schema([ColumnSchema("a", dtype="int64")])
    output_schema = Schema([ColumnSchema("output", dtype="int64")])

    model_node = input_schema.column_names >> PredictPyTorch(
        traced, input_schema, output_schema
    )
    ensemble = Ensemble(model_node, input_schema)
    ensemble_config, _ = ensemble.export(str(tmpdir))

    # run inference with missing input (that was present when model was compiled)
    # we're expecting a KeyError at runtime.
    request_df = pd.DataFrame({"a": [1]})

    with pytest.raises(tritonclient.utils.InferenceServerException) as exc_info:
        run_ensemble_on_tritonserver(
            str(tmpdir), input_schema, request_df, ["output"], ensemble_config.name
        )
    message = str(exc_info.value)
    assert "The following operation failed in the TorchScript interpreter" in message
    assert "RuntimeError: KeyError: b" in message

0 commit comments

Comments (0)