Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ jobs:
- test: TestVllmCustomHandlers
instance: g6
failure-prefix: lmi
- test: TestVllmCustomFormatters
instance: g6
failure-prefix: lmi
- test: TestVllmLora
instance: g6
failure-prefix: lmi
Expand Down
59 changes: 59 additions & 0 deletions engines/python/setup/djl_python/custom_handler_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import logging
from djl_python.inputs import Input
from djl_python.service_loader import load_model_service, has_function_in_module
from djl_python.async_utils import create_non_stream_output

logger = logging.getLogger(__name__)


class CustomHandlerError(Exception):
    """Raised when user-provided handler code in model.py fails.

    Carries the triggering exception both as an attribute and as the
    PEP 3134 ``__cause__`` so tracebacks show the original failure.
    """

    def __init__(self, message: str, original_exception: Exception):
        super().__init__(message)
        # Chain explicitly so `raise` without `from` still shows the cause.
        self.__cause__ = original_exception
        self.original_exception = original_exception


class CustomHandlerService:
    """Optionally wraps a user-supplied async ``handle`` function from model.py.

    When model.py cannot be loaded or exposes no ``handle`` function, the
    service stays uninitialized; callers check ``initialized`` and fall back
    to the default handler.
    """

    def __init__(self, properties: dict):
        # Populated by _initialize when a usable model.py is found.
        self.custom_handler = None
        self.initialized = False
        self._initialize(properties)

    def _initialize(self, properties: dict):
        """Try to load a ``handle`` function from model.py; never raises."""
        model_dir = properties.get("model_dir", ".")
        try:
            service = load_model_service(model_dir, "model.py", -1)
            if has_function_in_module(service.module, "handle"):
                self.custom_handler = service
                self.initialized = True
                logger.info("Loaded custom handler from model.py")
        except Exception as e:
            # A missing or broken model.py is not fatal: log and fall back.
            logger.debug(f"No custom handler found in model.py: {e}")

    async def handle(self, inputs: Input):
        """Invoke the custom handler, or return None when none is loaded.

        Exceptions raised by the handler are converted into a 424 error
        output rather than propagated to the caller.
        """
        if not self.custom_handler:
            return None
        try:
            return await self.custom_handler.invoke_handler_async(
                "handle", inputs)
        except Exception as e:
            logger.exception("Custom handler failed")
            return create_non_stream_output(
                "", error=f"Custom handler failed: {str(e)}", code=424)
15 changes: 15 additions & 0 deletions engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from djl_python.async_utils import handle_streaming_response, create_non_stream_output, _extract_lora_adapter
from djl_python.async_utils import register_adapter as _register_adapter, update_adapter as _update_adapter, unregister_adapter as _unregister_adapter
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
from djl_python.custom_handler_service import CustomHandlerService
from djl_python.rolling_batch.rolling_batch_vllm_utils import create_lora_request, get_lora_request

from djl_python.lmi_vllm.request_response_utils import (
Expand Down Expand Up @@ -303,12 +304,26 @@ async def pin_lora(self, lora_name: str, lora_alias: str):
lora_request.lora_int_id)


custom_service = None
service = VLLMHandler()


async def handle(
inputs: Input
) -> Optional[Union[Output, AsyncGenerator[Output, None]]]:
global custom_service
# Initialize custom service once
if custom_service is None:
custom_service = CustomHandlerService(inputs.get_properties())

# Try custom handler first
if custom_service.initialized:
logger.info("Using custom handler for request")
result = await custom_service.handle(inputs)
if result is not None:
logger.info("Custom handler completed successfully")
return result

if not service.initialized:
await service.initialize(inputs.get_properties())
logger.info("vllm service initialized")
Expand Down
60 changes: 60 additions & 0 deletions serving/docs/lmi/user_guides/vllm_user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,63 @@ For example, if you want to enable the `speculative_config`, you can do:

* `option.speculative_config={"model": "meta-llama/Llama3.2-1B-Instruct", "num_speculative_tokens": 5}`
* `OPTION_SPECULATIVE_CONFIG={"model": "meta-llama/Llama3.2-1B-Instruct", "num_speculative_tokens": 5}`


## Custom Handlers

**Note: Custom handler support is only available in async mode.**

vLLM async mode supports custom handlers that allow you to implement custom inference logic while still leveraging the vLLM engine.
To use a custom handler, create a `model.py` file in your model directory with an async `handle` function.

### Custom Handler Example

```python
from djl_python import Input, Output
from djl_python.encode_decode import decode
from djl_python.async_utils import create_non_stream_output
from vllm import LLM, SamplingParams

llm = None

async def handle(inputs: Input) -> Output:
"""Custom async handler with vLLM generate"""
global llm

# Initialize vLLM LLM if not already done
if llm is None:
properties = inputs.get_properties()
model_id = properties.get("model_id", "gpt2")
llm = LLM(model=model_id, tensor_parallel_size=1)

# Parse input
batch = inputs.get_batches()
raw_request = batch[0]
content_type = raw_request.get_property("Content-Type")
decoded_payload = decode(raw_request, content_type)

prompt = decoded_payload.get("inputs", "Hello")

# Create sampling parameters
sampling_params = SamplingParams(max_tokens=50, temperature=0.8)

# Generate using vLLM
outputs = llm.generate([prompt], sampling_params)
generated_text = outputs[0].outputs[0].text if outputs else "No output"

# Create response
response = {
"generated_text": generated_text
}

# Return properly formatted output
return create_non_stream_output(response)
```

### Key Points for Custom Handlers

- The `handle` function must be async and accept an `Input` parameter
- Use `create_non_stream_output()` or `handle_streaming_response()` from `djl_python.async_utils` to format the response
- Access model properties via `inputs.get_properties()`
- Parse request data using `decode()` from `djl_python.encode_decode`
- If no custom handler is found (no async `handle` function in `model.py`), the system automatically falls back to the default vLLM handler; if the custom handler raises an exception at request time, an error response with HTTP code 424 is returned instead of falling back
7 changes: 7 additions & 0 deletions tests/integration/examples/custom_handlers/import_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from djl_python import Input, Output
from non_existent_module import NonExistentClass

async def handle(inputs: Input):
    """Custom handler with import error"""
    # This body is never reached: the `non_existent_module` import above
    # fails at module load time, which is the failure mode under test.
    obj = NonExistentClass()
    return obj.process(inputs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from djl_python import Input, Output

# Missing handle function - should fall back to vLLM
def some_other_function():
    """Placeholder so this module loads cleanly yet exposes no `handle`."""
    message = "This is not a handle function"
    return message
6 changes: 6 additions & 0 deletions tests/integration/examples/custom_handlers/runtime_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from djl_python import Input, Output
from djl_python.encode_decode import decode

async def handle(inputs: Input):
    """Custom handler that raises runtime error"""
    # Fails unconditionally so integration tests can assert that a handler
    # error surfaces as a 424 error response rather than a normal result.
    raise RuntimeError("Custom handler intentional failure")
56 changes: 56 additions & 0 deletions tests/integration/examples/custom_handlers/success.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from djl_python import Input, Output
from djl_python.async_utils import create_non_stream_output
from djl_python.encode_decode import decode
from vllm import LLM, SamplingParams
import json

llm = None

async def handle(inputs: Input):
    """Custom async handler: run a plain vLLM generate call and tag the
    response with `custom_handler_used` so tests can detect this path ran."""
    global llm

    print("[CUSTOM_HANDLER] Starting handle function")

    # Lazily create the engine on the first request only.
    if llm is None:
        print("[CUSTOM_HANDLER] Initializing vLLM engine")
        props = inputs.get_properties()
        model_id = props.get("model_id", "gpt2")
        llm = LLM(model=model_id, tensor_parallel_size=1)
        print(f"[CUSTOM_HANDLER] vLLM engine initialized with model: {model_id}")

    # Decode the first request in the batch according to its content type.
    first_request = inputs.get_batches()[0]
    payload = decode(first_request, first_request.get_property("Content-Type"))

    # Fall back to a default prompt on missing/blank input.
    prompt = payload.get("inputs", "Hello")
    if not prompt or not prompt.strip():
        prompt = "Hello"

    print(f"[CUSTOM_HANDLER] Using prompt: {prompt}")

    print("[CUSTOM_HANDLER] Starting generation")
    results = llm.generate([prompt], SamplingParams(max_tokens=50, temperature=0.8))
    generated_text = results[0].outputs[0].text if results else "No output"

    print(f"[CUSTOM_HANDLER] Generated text: {generated_text}")

    # The marker key lets the integration test verify this handler was used.
    response = {
        "custom_handler_used": True,
        "generated_text": generated_text
    }

    print(f"[CUSTOM_HANDLER] Response created: {response}")

    output = create_non_stream_output(response)

    print("[CUSTOM_HANDLER] Output object created, returning")
    return output
8 changes: 8 additions & 0 deletions tests/integration/examples/custom_handlers/syntax_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from djl_python import Input, Output
from djl_python.encode_decode import decode
from vllm import LLM, SamplingParams

# Syntax error - missing colon
async def handle(inputs: Input)  # NOTE: colon deliberately omitted; this fixture must fail to parse
    """Custom handler with syntax error"""
    return None
30 changes: 29 additions & 1 deletion tests/integration/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def awscurl_run(data,
f"-N {num_run} -X POST {endpoint} --connect-timeout 300")
for header in headers:
command += f" -H {header}"
command = f" {command_data} {delay} {json_output} {jq} -P -t"
command += f" {command_data} {delay} {json_output} {jq} -P -t"
if tokenizer:
command = f"TOKENIZER={tokenizer} {command}"
if output:
Expand Down Expand Up @@ -1730,6 +1730,32 @@ def test_handler_rolling_batch(model, model_spec):


def test_custom_handler_async(model, model_spec):
    """Send generation requests and assert the custom-handler marker is
    present in every (stream and non-stream) response."""
    modelspec_checker(model, model_spec)
    spec = model_spec[args.model]
    if "worker" in spec:
        check_worker_number(spec["worker"])

    # Build the base request once; only the stream flag changes per pass.
    parameters = {
        "do_sample": True,
        "max_new_tokens": spec["seq_length"][0],
        "details": True,
    }
    parameters.update(spec.get("parameters", {}))
    req = {"inputs": batch_generation(1)[0], "parameters": parameters}
    if "adapters" in spec:
        req["adapters"] = spec["adapters"][0]

    for stream in spec.get("stream", [False, True]):
        req["stream"] = stream
        LOGGER.info(f"req {req}")
        res = send_json(req)
        message = res.content.decode("utf-8")
        LOGGER.info(f"res: {message}")
        response_checker(res, message)
        assert "custom_handler_used" in message, "Output does not contain custom_handler_used tag"


def test_custom_formatter_async(model, model_spec):
modelspec_checker(model, model_spec)
spec = model_spec[args.model]
if "worker" in spec:
Expand Down Expand Up @@ -2329,6 +2355,8 @@ def run(raw_args):
elif args.handler == "vllm":
test_handler_rolling_batch(args.model, vllm_model_spec)
elif args.handler == "custom":
test_custom_formatter_async(args.model, custom_formatter_spec)
elif args.handler == "custom_handler":
test_custom_handler_async(args.model, custom_formatter_spec)
elif args.handler == "vllm_adapters":
test_handler_adapters(args.model, vllm_model_spec)
Expand Down
22 changes: 21 additions & 1 deletion tests/integration/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,25 @@ def build_vllm_async_model(model):
adapter_names=adapter_names)


def build_vllm_async_model_with_custom_handler(model, handler_type="success"):
    """Build a vLLM async model directory with a custom handler as model.py.

    Args:
        model: key into vllm_model_list selecting the base model options.
        handler_type: name of a fixture under examples/custom_handlers/
            (e.g. "success", "syntax_error", "runtime_error",
            "missing_handle", "import_error").

    Raises:
        ValueError: if model is not a known vllm model.
        FileNotFoundError: if the requested handler fixture does not exist.
    """
    if model not in vllm_model_list.keys():
        raise ValueError(
            f"{model} is not one of the supporting handler {list(vllm_model_list.keys())}"
        )
    # Copy the spec so repeated builds do not mutate the shared module-level
    # vllm_model_list entry (in-place writes would leak into later builds).
    options = dict(vllm_model_list[model])
    options["engine"] = "Python"
    options["option.rolling_batch"] = "disable"
    options["option.async_mode"] = "true"
    options["option.entryPoint"] = "djl_python.lmi_vllm.vllm_async_service"
    write_model_artifacts(options)

    # Copy custom handler from examples; fail loudly if the fixture is
    # missing so a typo in handler_type cannot silently test the fallback
    # path instead of the intended handler.
    source_file = f"examples/custom_handlers/{handler_type}.py"
    target_file = "models/test/model.py"
    if not os.path.exists(source_file):
        raise FileNotFoundError(
            f"custom handler fixture not found: {source_file}")
    shutil.copy2(source_file, target_file)


def build_vllm_async_model_custom_formatters(model, error_type=None):
if model not in vllm_model_list.keys():
raise ValueError(
Expand Down Expand Up @@ -1890,7 +1909,8 @@ def build_stateful_model(model):
'correctness': build_correctness_model,
'text_embedding': build_text_embedding_model,
'vllm_async': build_vllm_async_model,
'vllm_async_custom_formatters': build_vllm_async_model_custom_formatters
'vllm_async_custom_formatters': build_vllm_async_model_custom_formatters,
'vllm_async_custom_handler': build_vllm_async_model_with_custom_handler
}

if __name__ == '__main__':
Expand Down
41 changes: 41 additions & 0 deletions tests/integration/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,47 @@ def test_llama31_8b_tp2_pp2_specdec(self):
@pytest.mark.gpu_4
class TestVllmCustomHandlers:
    """Integration coverage for custom model.py handlers in vLLM async mode."""

    def test_custom_handler_success(self):
        # Happy path: requests are served by the custom handler end to end.
        with Runner('lmi', 'gpt-neox-20b-custom-handler') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                "gpt-neox-20b-custom")
            runner.launch()
            client.run(["custom_handler", "gpt-neox-20b"])

    def test_custom_handler_syntax_error(self):
        # A handler that cannot be parsed should abort container startup.
        with Runner('lmi', 'gpt-neox-20b-custom-handler-syntax-error') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                "gpt-neox-20b-custom", "syntax_error")
            with pytest.raises(Exception):
                runner.launch()

    def test_custom_handler_runtime_error(self):
        # A handler that raises at request time should yield a 424 response.
        with Runner('lmi', 'gpt-neox-20b-custom-handler-runtime-error') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                "gpt-neox-20b-custom", "runtime_error")
            runner.launch()
            with pytest.raises(ValueError, match=r".*424.*"):
                client.run(["custom_handler", "gpt-neox-20b"])

    def test_custom_handler_missing_handle(self):
        # No handle() in model.py: serving falls back to the stock vLLM path.
        with Runner('lmi', 'gpt-neox-20b-custom-handler-missing') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                "gpt-neox-20b-custom", "missing_handle")
            runner.launch()
            client.run(["vllm", "gpt-neox-20b"])

    def test_custom_handler_import_error(self):
        # An unresolvable import in model.py should abort container startup.
        with Runner('lmi', 'gpt-neox-20b-custom-handler-import-error') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                "gpt-neox-20b-custom", "import_error")
            with pytest.raises(Exception):
                runner.launch()


@pytest.mark.vllm
@pytest.mark.gpu_4
class TestVllmCustomFormatters:

def test_gpt_neox_20b_custom(self):
with Runner('lmi', 'gpt-neox-20b') as r:
prepare.build_vllm_async_model_custom_formatters(
Expand Down
Loading