24 changes: 22 additions & 2 deletions lib/llm/src/grpc/service/kserve.rs
@@ -361,7 +361,17 @@ impl GrpcInferenceService for KserveService
let stream = tensor_response_stream(state.clone(), tensor_request, true).await?;

pin_mut!(stream);
while let Some(response) = stream.next().await {
while let Some(delta) = stream.next().await {
let response = match delta.ok() {
Err(e) => {
yield ModelStreamInferResponse {
error_message: e.to_string(),
infer_response: None
};
continue;
}
Ok(response) => response,
};
match response.data {
Some(data) => {
let data = ExtendedNvCreateTensorResponse {response: data,
@@ -412,7 +422,17 @@ impl GrpcInferenceService for KserveService {

if streaming {
pin_mut!(stream);
while let Some(response) = stream.next().await {
while let Some(delta) = stream.next().await {
let response = match delta.ok() {
Err(e) => {
yield ModelStreamInferResponse {
error_message: e.to_string(),
infer_response: None
};
continue;
}
Ok(response) => response,
};
match response.data {
Some(data) => {
let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
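The pattern both kserve.rs hunks introduce: when a delta pulled from the worker stream carries an error, the handler now yields a ModelStreamInferResponse with error_message set and infer_response left empty, then moves on to the next delta instead of tearing down the gRPC stream. Below is a minimal standalone sketch of that idea, assuming stream items that behave like a Result; StreamInferResponse, the string payloads, and the hard-coded deltas are stand-ins for ModelStreamInferResponse and the real tensor_response_stream output, not Dynamo types.

// Minimal sketch: convert failed stream items into error-only responses
// instead of ending the stream. Requires the `futures` and `tokio` crates.
use futures::{pin_mut, stream, StreamExt};

#[derive(Debug)]
struct StreamInferResponse {
    error_message: String,
    infer_response: Option<String>,
}

#[tokio::main]
async fn main() {
    // Stand-in for the worker stream: one bad delta between two good ones.
    let deltas = stream::iter(vec![
        Ok("tensor-0".to_string()),
        Err("missing field `data_type`".to_string()),
        Ok("tensor-1".to_string()),
    ]);
    pin_mut!(deltas);

    while let Some(delta) = deltas.next().await {
        let response = match delta {
            // A failed delta becomes an error-only response; the stream keeps going.
            Err(e) => StreamInferResponse {
                error_message: e,
                infer_response: None,
            },
            Ok(data) => StreamInferResponse {
                error_message: String::new(),
                infer_response: Some(data),
            },
        };
        println!("{response:?}");
    }
}

The new tests below drive exactly this path: the echo worker is told to emit a malformed tensor or raise an exception, and the client asserts that the resulting error message reaches it.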
10 changes: 10 additions & 0 deletions tests/frontend/grpc/echo_tensor_worker.py
@@ -75,6 +75,16 @@ async def generate(request, context):
params = {}
if "parameters" in request:
params.update(request["parameters"])
if "malformed_response" in request["parameters"]:
request["tensors"][0]["data"] = {"values": [0, 1, 2]}
yield {
"model": request["model"],
"tensors": request["tensors"],
"parameters": params,
}
return
elif "raise_exception" in request["parameters"]:
raise ValueError("Intentional exception raised by echo_tensor_worker.")

params["processed"] = {"bool": True}

110 changes: 110 additions & 0 deletions tests/frontend/grpc/test_tensor_mocker_engine.py
@@ -14,10 +14,14 @@

import logging
import os
import queue
import shutil
from functools import partial

import numpy as np
import pytest
import triton_echo_client
import tritonclient.grpc as grpcclient

from tests.utils.constants import QWEN
from tests.utils.managed_process import ManagedProcess
@@ -105,4 +109,110 @@ def test_echo(start_services_with_echo_worker) -> None:
client = triton_echo_client.TritonEchoClient(grpc_port=frontend_port)
client.check_health()
client.run_infer()
client.run_stream_infer()
client.get_config()


@pytest.mark.e2e
@pytest.mark.pre_merge
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required)
@pytest.mark.parallel
@pytest.mark.parametrize(
"request_params",
[
{"malformed_response": True},
{"raise_exception": True},
],
ids=["malformed_response", "raise_exception"],
)
def test_model_infer_failure(start_services_with_echo_worker, request_params):
"""Test gRPC request-level parameters are echoed through tensor models.

The worker acts as an identity function: echoes input tensors unchanged and
returns all request parameters plus a "processed" flag to verify the complete
parameter flow through the gRPC frontend.
"""
frontend_port = start_services_with_echo_worker
client = grpcclient.InferenceServerClient(f"localhost:{frontend_port}")

input_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
inputs = [grpcclient.InferInput("INPUT", input_data.shape, "FP32")]
inputs[0].set_data_from_numpy(input_data)

# expect exception during inference
with pytest.raises(Exception) as excinfo:
client.infer("echo", inputs=inputs, parameters=request_params)
if "malformed_response" in request_params:
assert "missing field `data_type`" in str(excinfo.value).lower()
elif "raise_exception" in request_params:
assert "intentional exception" in str(excinfo.value).lower()


@pytest.mark.e2e
@pytest.mark.pre_merge
@pytest.mark.gpu_0 # Echo tensor worker is CPU-only (no GPU required)
@pytest.mark.parallel
@pytest.mark.parametrize(
"request_params",
[
{"malformed_response": True},
{"raise_exception": True},
],
ids=["malformed_response", "raise_exception"],
)
def test_model_stream_infer_failure(start_services_with_echo_worker, request_params):
"""Test gRPC request-level parameters are echoed through tensor models.

The worker acts as an identity function: echoes input tensors unchanged and
returns all request parameters plus a "processed" flag to verify the complete
parameter flow through the gRPC frontend.
"""
frontend_port = start_services_with_echo_worker
client = grpcclient.InferenceServerClient(f"localhost:{frontend_port}")

input_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
inputs = [grpcclient.InferInput("INPUT", input_data.shape, "FP32")]
inputs[0].set_data_from_numpy(input_data)

class UserData:
def __init__(self):
self._completed_requests: queue.Queue[
grpcclient.InferResult | Exception
] = queue.Queue()

# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient provides the result of an
# inference as grpcclient.InferResult in result. For a successful
# inference, error will be None; otherwise it will be a
# tritonclientutils.InferenceServerException holding the error details.
def callback(user_data, result, error):
print("Received callback")
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)

user_data = UserData()
client.start_stream(
callback=partial(callback, user_data),
)

client.async_stream_infer(
model_name="echo",
inputs=inputs,
parameters=request_params,
)

# For stream infer, errors are delivered to the callback rather than raised
# directly, so re-raise them here for pytest.raises to catch.
with pytest.raises(Exception) as excinfo:
data_item = user_data._completed_requests.get(timeout=5)
if isinstance(data_item, Exception):
print("Raising exception received from stream infer callback")
raise data_item
if "malformed_response" in request_params:
assert "missing field `data_type`" in str(excinfo.value).lower()
elif "raise_exception" in request_params:
assert "intentional exception" in str(excinfo.value).lower()
else:
assert False, "Expected exception was not raised"
72 changes: 72 additions & 0 deletions tests/frontend/grpc/triton_echo_client.py
@@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0


import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient

@@ -63,9 +66,78 @@ def run_infer(self) -> None:
assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)

def run_stream_infer(self) -> None:
triton_client = self._client()
model_name = "echo"

inputs = [
grpcclient.InferInput("INPUT0", [16], "INT32"),
grpcclient.InferInput("INPUT1", [16], "BYTES"),
]

input0_data = np.arange(start=0, stop=16, dtype=np.int32).reshape([16])
input1_data = np.array(
[str(x).encode("utf-8") for x in input0_data.reshape(input0_data.size)],
dtype=np.object_,
).reshape([16])

inputs[0].set_data_from_numpy(input0_data)
inputs[1].set_data_from_numpy(input1_data)

class UserData:
def __init__(self):
self._completed_requests = queue.Queue()

# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient provides the result of an
# inference as grpcclient.InferResult in result. For a successful
# inference, error will be None; otherwise it will be a
# tritonclientutils.InferenceServerException holding the error details.
def callback(user_data, result, error):
print("Received callback")
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)

user_data = UserData()
triton_client.start_stream(
callback=partial(callback, user_data),
)

triton_client.async_stream_infer(
model_name=model_name,
inputs=inputs,
)

data_item = user_data._completed_requests.get()
assert (
isinstance(data_item, Exception) is False
), f"Stream inference failed: {data_item}"

output0_data = data_item.as_numpy("INPUT0")
output1_data = data_item.as_numpy("INPUT1")

assert (
output0_data is not None
), "Expected response to include output tensor 'INPUT0'"
assert (
output1_data is not None
), "Expected response to include output tensor 'INPUT1'"
assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)

def get_config(self) -> None:
triton_client = self._client()
model_name = "echo"
response = triton_client.get_model_config(model_name=model_name)
# Check one of the fields that can only be set by providing a Triton model config
assert response.config.model_transaction_policy.decoupled


if __name__ == "__main__":
client = TritonEchoClient(grpc_port=8000)
client.check_health()
client.run_infer()
client.get_config()
print("Triton echo client ran successfully.")