Skip to content

Commit 71c7607

Browse files
ksuma2109 (Suma Kasa)
authored and committed
[LMIv16] Add integration tests for CustomHandlers and fix (#2904)
Co-authored-by: Suma Kasa <sumakasa@amazon.com>
1 parent e437d44 commit 71c7607

File tree

12 files changed

+310
-2
lines changed

12 files changed

+310
-2
lines changed

.github/workflows/integration.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ jobs:
160160
- test: TestVllmCustomHandlers
161161
instance: g6
162162
failure-prefix: lmi
163+
- test: TestVllmCustomFormatters
164+
instance: g6
165+
failure-prefix: lmi
163166
- test: TestVllmLora
164167
instance: g6
165168
failure-prefix: lmi
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
import logging
15+
from djl_python.inputs import Input
16+
from djl_python.service_loader import load_model_service, has_function_in_module
17+
from djl_python.async_utils import create_non_stream_output
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class CustomHandlerError(Exception):
    """Raised when user-supplied custom handler code fails.

    Wraps the root-cause exception so callers can reach it either through
    the ``original_exception`` attribute or standard exception chaining
    (``__cause__``).
    """

    def __init__(self, message: str, original_exception: Exception):
        super().__init__(message)
        # Chain the cause the standard way so tracebacks show
        # "The above exception was the direct cause of ...", and keep a
        # direct reference for programmatic access.
        self.__cause__ = original_exception
        self.original_exception = original_exception
31+
class CustomHandlerService:
    """Loads and dispatches to an optional user-provided ``handle`` function.

    During construction we try to load ``model.py`` from the model directory.
    If it exposes a ``handle`` function, ``initialized`` becomes True and
    requests can be routed through :meth:`handle`; otherwise callers are
    expected to fall back to the default engine handler.
    """

    def __init__(self, properties: dict):
        # Set by _initialize when a usable model.py is found.
        self.custom_handler = None
        self.initialized = False
        self._initialize(properties)

    def _initialize(self, properties: dict):
        """Best-effort load of model.py; any failure leaves us uninitialized."""
        model_dir = properties.get("model_dir", ".")
        try:
            service = load_model_service(model_dir, "model.py", -1)
            if has_function_in_module(service.module, "handle"):
                self.custom_handler = service
                logger.info("Loaded custom handler from model.py")
                self.initialized = True
        except Exception as e:
            # A missing or broken model.py is expected for most deployments,
            # so log at debug level and silently stay uninitialized.
            logger.debug(f"No custom handler found in model.py: {e}")

    async def handle(self, inputs: Input):
        """Invoke the loaded custom handler.

        Returns None when no handler is loaded; on handler failure, returns
        a non-stream error output with HTTP code 424 instead of raising.
        """
        if not self.custom_handler:
            return None
        try:
            return await self.custom_handler.invoke_handler_async(
                "handle", inputs)
        except Exception as e:
            logger.exception("Custom handler failed")
            # Surface the failure to the client rather than crashing the
            # worker; 424 marks a failed model dependency.
            return create_non_stream_output(
                "", error=f"Custom handler failed: {str(e)}", code=424)

engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from djl_python.async_utils import handle_streaming_response, create_non_stream_output, _extract_lora_adapter
3434
from djl_python.async_utils import register_adapter as _register_adapter, update_adapter as _update_adapter, unregister_adapter as _unregister_adapter
3535
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
36+
from djl_python.custom_handler_service import CustomHandlerService
3637
from djl_python.rolling_batch.rolling_batch_vllm_utils import create_lora_request, get_lora_request
3738

3839
from djl_python.lmi_vllm.request_response_utils import (
@@ -303,12 +304,26 @@ async def pin_lora(self, lora_name: str, lora_alias: str):
303304
lora_request.lora_int_id)
304305

305306

307+
custom_service = None
306308
service = VLLMHandler()
307309

308310

309311
async def handle(
310312
inputs: Input
311313
) -> Optional[Union[Output, AsyncGenerator[Output, None]]]:
314+
global custom_service
315+
# Initialize custom service once
316+
if custom_service is None:
317+
custom_service = CustomHandlerService(inputs.get_properties())
318+
319+
# Try custom handler first
320+
if custom_service.initialized:
321+
logger.info("Using custom handler for request")
322+
result = await custom_service.handle(inputs)
323+
if result is not None:
324+
logger.info("Custom handler completed successfully")
325+
return result
326+
312327
if not service.initialized:
313328
await service.initialize(inputs.get_properties())
314329
logger.info("vllm service initialized")

serving/docs/lmi/user_guides/vllm_user_guide.md

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,63 @@ For example, if you want to enable the `speculative_config`, you can do:
190190

191191
* `option.speculative_config={"model": "meta-llama/Llama3.2-1B-Instruct", "num_speculative_tokens": 5}`
192192
* `OPTION_SPECULATIVE_CONFIG={"model": "meta-llama/Llama3.2-1B-Instruct", "num_speculative_tokens": 5}`
193+
194+
195+
## Custom Handlers
196+
197+
**Note: Custom handler support is only available in async mode.**
198+
199+
vLLM async mode supports custom handlers that allow you to implement custom inference logic while still leveraging the vLLM engine.
200+
To use a custom handler, create a `model.py` file in your model directory with an async `handle` function.
201+
202+
### Custom Handler Example
203+
204+
```python
205+
from djl_python import Input, Output
206+
from djl_python.encode_decode import decode
207+
from djl_python.async_utils import create_non_stream_output
208+
from vllm import LLM, SamplingParams
209+
210+
llm = None
211+
212+
async def handle(inputs: Input) -> Output:
213+
"""Custom async handler with vLLM generate"""
214+
global llm
215+
216+
# Initialize vLLM LLM if not already done
217+
if llm is None:
218+
properties = inputs.get_properties()
219+
model_id = properties.get("model_id", "gpt2")
220+
llm = LLM(model=model_id, tensor_parallel_size=1)
221+
222+
# Parse input
223+
batch = inputs.get_batches()
224+
raw_request = batch[0]
225+
content_type = raw_request.get_property("Content-Type")
226+
decoded_payload = decode(raw_request, content_type)
227+
228+
prompt = decoded_payload.get("inputs", "Hello")
229+
230+
# Create sampling parameters
231+
sampling_params = SamplingParams(max_tokens=50, temperature=0.8)
232+
233+
# Generate using vLLM
234+
outputs = llm.generate([prompt], sampling_params)
235+
generated_text = outputs[0].outputs[0].text if outputs else "No output"
236+
237+
# Create response
238+
response = {
239+
"generated_text": generated_text
240+
}
241+
242+
# Return properly formatted output
243+
return create_non_stream_output(response)
244+
```
245+
246+
### Key Points for Custom Handlers
247+
248+
- The `handle` function must be async and accept an `Input` parameter
249+
- Use `create_non_stream_output()` or `handle_streaming_response()` from `djl_python.async_utils` to format the response
250+
- Access model properties via `inputs.get_properties()`
251+
- Parse request data using `decode()` from `djl_python.encode_decode`
252+
- If no custom handler is found in `model.py` (or the file fails to load), the system automatically falls back to the default vLLM handler; if a loaded custom handler raises an exception at request time, the client receives a 424 error response instead
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from djl_python import Input, Output
2+
from non_existent_module import NonExistentClass
3+
4+
async def handle(inputs: Input):
    """Intentionally broken handler: constructing the class fails because
    its import does not resolve, exercising the fallback path."""
    return NonExistentClass().process(inputs)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from djl_python import Input, Output
2+
3+
# Missing handle function - should fall back to vLLM
4+
def some_other_function():
    """Deliberately NOT named ``handle`` — the loader must ignore this
    module and fall back to the default vLLM handler."""
    message = "This is not a handle function"
    return message
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from djl_python import Input, Output
2+
from djl_python.encode_decode import decode
3+
4+
async def handle(inputs: Input):
    """Handler that always fails at request time, used to verify the
    server returns an error response instead of crashing."""
    raise RuntimeError("Custom handler intentional failure")
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from djl_python import Input, Output
2+
from djl_python.async_utils import create_non_stream_output
3+
from djl_python.encode_decode import decode
4+
from vllm import LLM, SamplingParams
5+
import json
6+
7+
llm = None
8+
9+
async def handle(inputs: Input):
    """Working custom handler fixture: runs a plain vLLM generate and tags
    the response with ``custom_handler_used`` so the integration test can
    confirm this path (not the default handler) served the request."""
    global llm

    print("[CUSTOM_HANDLER] Starting handle function")

    # Lazily build the engine on first request and reuse it afterwards.
    if llm is None:
        print("[CUSTOM_HANDLER] Initializing vLLM engine")
        props = inputs.get_properties()
        model_id = props.get("model_id", "gpt2")
        llm = LLM(model=model_id, tensor_parallel_size=1)
        print(f"[CUSTOM_HANDLER] vLLM engine initialized with model: {model_id}")

    # Decode the first request in the batch using its declared content type.
    first_request = inputs.get_batches()[0]
    payload = decode(first_request, first_request.get_property("Content-Type"))

    prompt = payload.get("inputs", "Hello")
    # Guard against empty/whitespace-only prompts.
    if not prompt or prompt.strip() == "":
        prompt = "Hello"

    print(f"[CUSTOM_HANDLER] Using prompt: {prompt}")

    # Generate with fixed sampling settings.
    print("[CUSTOM_HANDLER] Starting generation")
    results = llm.generate([prompt], SamplingParams(max_tokens=50, temperature=0.8))
    if results:
        generated_text = results[0].outputs[0].text
    else:
        generated_text = "No output"

    print(f"[CUSTOM_HANDLER] Generated text: {generated_text}")

    # Marker key lets the test assert the custom handler produced the output.
    response = {
        "custom_handler_used": True,
        "generated_text": generated_text
    }

    print(f"[CUSTOM_HANDLER] Response created: {response}")

    output = create_non_stream_output(response)

    print("[CUSTOM_HANDLER] Output object created, returning")
    return output
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from djl_python import Input, Output
2+
from djl_python.encode_decode import decode
3+
from vllm import LLM, SamplingParams
4+
5+
# Syntax error - missing colon
6+
# NOTE(test fixture): the missing ":" on the def line below is intentional.
# This file must fail to import so the server falls back to the default
# vLLM handler — do not "fix" it.
async def handle(inputs: Input)
    """Custom handler with syntax error"""
    return None

tests/integration/llm/client.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,7 @@ def awscurl_run(data,
11791179
f"-N {num_run} -X POST {endpoint} --connect-timeout 300")
11801180
for header in headers:
11811181
command += f" -H {header}"
1182-
command = f" {command_data} {delay} {json_output} {jq} -P -t"
1182+
command += f" {command_data} {delay} {json_output} {jq} -P -t"
11831183
if tokenizer:
11841184
command = f"TOKENIZER={tokenizer} {command}"
11851185
if output:
@@ -1730,6 +1730,32 @@ def test_handler_rolling_batch(model, model_spec):
17301730

17311731

17321732
def test_custom_handler_async(model, model_spec):
    """Exercise the async custom-handler path: every response, streaming and
    non-streaming, must carry the ``custom_handler_used`` marker emitted by
    the custom model.py fixture."""
    modelspec_checker(model, model_spec)
    spec = model_spec[args.model]
    if "worker" in spec:
        check_worker_number(spec["worker"])
    # dryrun phase
    req = {
        "inputs": batch_generation(1)[0],
        "parameters": {
            "do_sample": True,
            "max_new_tokens": spec["seq_length"][0],
            "details": True,
        },
    }
    # Spec-level parameters override/extend the defaults above.
    req["parameters"].update(spec.get("parameters", {}))
    if "adapters" in spec:
        req["adapters"] = spec["adapters"][0]

    for stream in spec.get("stream", [False, True]):
        req["stream"] = stream
        LOGGER.info(f"req {req}")
        res = send_json(req)
        message = res.content.decode("utf-8")
        LOGGER.info(f"res: {message}")
        response_checker(res, message)
        assert "custom_handler_used" in message, "Output does not contain custom_handler_used tag"
1758+
def test_custom_formatter_async(model, model_spec):
17331759
modelspec_checker(model, model_spec)
17341760
spec = model_spec[args.model]
17351761
if "worker" in spec:
@@ -2329,6 +2355,8 @@ def run(raw_args):
23292355
elif args.handler == "vllm":
23302356
test_handler_rolling_batch(args.model, vllm_model_spec)
23312357
elif args.handler == "custom":
2358+
test_custom_formatter_async(args.model, custom_formatter_spec)
2359+
elif args.handler == "custom_handler":
23322360
test_custom_handler_async(args.model, custom_formatter_spec)
23332361
elif args.handler == "vllm_adapters":
23342362
test_handler_adapters(args.model, vllm_model_spec)

0 commit comments

Comments
 (0)