diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index afa3cf73c..198a6a786 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -160,6 +160,9 @@ jobs: - test: TestVllmCustomHandlers instance: g6 failure-prefix: lmi + - test: TestVllmCustomFormatters + instance: g6 + failure-prefix: lmi # - test: TestVllmLora # instance: g6 # failure-prefix: lmi diff --git a/engines/python/setup/djl_python/custom_handler_service.py b/engines/python/setup/djl_python/custom_handler_service.py new file mode 100644 index 000000000..74ce47aca --- /dev/null +++ b/engines/python/setup/djl_python/custom_handler_service.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
import logging
import os

from djl_python.inputs import Input
from djl_python.service_loader import load_model_service, has_function_in_module
from djl_python.async_utils import create_non_stream_output

logger = logging.getLogger(__name__)


class CustomHandlerError(Exception):
    """Exception raised when custom handler code fails.

    The triggering exception is stored on ``original_exception`` and also
    assigned to ``__cause__``.
    """

    def __init__(self, message: str, original_exception: Exception):
        super().__init__(message)
        self.original_exception = original_exception
        self.__cause__ = original_exception


class CustomHandlerService:
    """Optional wrapper around a user-supplied ``handle`` function in ``model.py``.

    After construction, ``initialized`` is True only when ``model.py`` was
    loaded and its module exposes a ``handle`` function; in that case
    ``custom_handler`` holds the loaded service. Otherwise ``custom_handler``
    stays ``None`` and :meth:`handle` returns ``None``.
    """

    def __init__(self, properties: dict):
        """Try to load the custom handler described by ``properties``.

        Args:
            properties: Model properties; ``model_dir`` (default ``"."``) is
                the directory passed to the service loader.
        """
        self.custom_handler = None
        self.initialized = False
        self._initialize(properties)

    def _initialize(self, properties: dict):
        """Load ``model.py`` and keep it only if it defines ``handle``.

        Loading is best-effort: any failure leaves the service uninitialized
        instead of raising.
        """
        model_dir = properties.get("model_dir", ".")
        try:
            service = load_model_service(model_dir, "model.py", -1)
            if has_function_in_module(service.module, "handle"):
                self.custom_handler = service
                logger.info("Loaded custom handler from model.py")
                self.initialized = True
            else:
                logger.debug(
                    "model.py does not define a handle function; "
                    "no custom handler loaded")
        except Exception as e:
            # A model.py that exists but cannot be loaded is most likely a
            # user error (syntax/import problem), so surface it with a
            # traceback instead of hiding it at DEBUG level.
            # Assumes the loader resolves "model.py" relative to model_dir --
            # TODO confirm against load_model_service.
            if os.path.isfile(os.path.join(model_dir, "model.py")):
                logger.warning(
                    "Failed to load custom handler from model.py; "
                    "custom handler disabled",
                    exc_info=True)
            else:
                logger.debug("No custom handler found in model.py: %s", e)

    async def handle(self, inputs: Input):
        """Invoke the custom ``handle`` function, if one was loaded.

        Args:
            inputs: The request to forward to the custom handler.

        Returns:
            ``None`` when no custom handler is loaded; otherwise whatever the
            custom handler returns. If the handler raises, the exception is
            logged and an error output built with
            ``create_non_stream_output`` (code 424) is returned instead.
        """
        if not self.custom_handler:
            return None
        try:
            return await self.custom_handler.invoke_handler_async(
                "handle", inputs)
        except Exception as e:
            logger.exception("Custom handler failed")
            return create_non_stream_output(
                "", error=f"Custom handler failed: {e}", code=424)
from djl_python.custom_handler_service import CustomHandlerService

# Module-level state shared by every call to handle().
# custom_service is created on the first request because it needs the
# properties carried by that request; vllm_service is constructed here but is
# only initialized when a request actually falls through to it.
custom_service = None
vllm_service = VLLMHandler()


async def handle(
        inputs: Input
) -> Optional[Union[Output, AsyncGenerator[Output, None]]]:
    """Entry point: try the custom handler first, then the vLLM handler.

    Args:
        inputs: The incoming request.

    Returns:
        The custom handler's result when a custom handler is loaded and it
        returns something other than ``None``. Otherwise ``None`` for an
        empty request, or the result of ``vllm_service.inference(inputs)``.
    """
    global custom_service
    # Initialize custom service once
    if custom_service is None:
        custom_service = CustomHandlerService(inputs.get_properties())

    # Try custom handler first
    if custom_service.initialized:
        logger.debug("Using custom handler for request")
        result = await custom_service.handle(inputs)
        if result is not None:
            # CustomHandlerService.handle converts handler exceptions into an
            # error output, so a non-None result is not necessarily a
            # success; keep the message neutral and off the INFO level to
            # avoid per-request log noise.
            logger.debug("Custom handler returned a response")
            return result

    # Fall back to vLLM handler
    if not vllm_service.initialized:
        await vllm_service.initialize(inputs.get_properties())
        logger.info("Default vLLM service initialized")

    if inputs.is_empty():
        return None

    return await vllm_service.inference(inputs)
100644 --- a/serving/docs/lmi/user_guides/vllm_user_guide.md +++ b/serving/docs/lmi/user_guides/vllm_user_guide.md @@ -190,3 +190,62 @@ For example, if you want to enable the `speculative_config`, you can do: * `option.speculative_config={"model": "meta-llama/Llama3.2-1B-Instruct", "num_speculative_tokens": 5}` * `OPTION_SPECULATIVE_CONFIG={"model": "meta-llama/Llama3.2-1B-Instruct", "num_speculative_tokens": 5}` + +## Custom Handlers + +**Note: Custom handler support is only available in async mode.** + +vLLM async mode supports custom handlers that allow you to implement custom inference logic while still leveraging the vLLM engine. +To use a custom handler, create a `model.py` file in your model directory with an async `handle` function. + +### Custom Handler Example + +```python +from djl_python import Input, Output +from djl_python.encode_decode import decode +from djl_python.async_utils import create_non_stream_output +from vllm import LLM, SamplingParams + +llm = None + +async def handle(inputs: Input) -> Output: + """Custom async handler with vLLM generate""" + global llm + + # Initialize vLLM LLM if not already done + if llm is None: + properties = inputs.get_properties() + model_id = properties.get("model_id", "gpt2") + llm = LLM(model=model_id, tensor_parallel_size=1) + + # Parse input + batch = inputs.get_batches() + raw_request = batch[0] + content_type = raw_request.get_property("Content-Type") + decoded_payload = decode(raw_request, content_type) + + prompt = decoded_payload.get("inputs", "Hello") + + # Create sampling parameters + sampling_params = SamplingParams(max_tokens=50, temperature=0.8) + + # Generate using vLLM + outputs = llm.generate([prompt], sampling_params) + generated_text = outputs[0].outputs[0].text if outputs else "No output" + + # Create response + response = { + "generated_text": generated_text + } + + # Return properly formatted output + return create_non_stream_output(response) +``` + +### Key Points for Custom Handlers + +- 
The `handle` function must be async and accept an `Input` parameter +- Use `create_non_stream_output()` or `handle_streaming_response` from `djl_python.async_utils` to format the response +- Access model properties via `inputs.get_properties()` +- Parse request data using `decode()` from `djl_python.encode_decode` +- If `model.py` is not found or does not define a `handle` function, the system automatically falls back to the default vLLM handler; if the custom handler raises an exception while serving a request, an error response with code 424 is returned instead of falling back diff --git a/tests/integration/examples/custom_handlers/import_error.py b/tests/integration/examples/custom_handlers/import_error.py new file mode 100644 index 000000000..b1c6b83e5 --- /dev/null +++ b/tests/integration/examples/custom_handlers/import_error.py @@ -0,0 +1,7 @@ +from djl_python import Input, Output +from non_existent_module import NonExistentClass + +async def handle(inputs: Input): + """Custom handler with import error""" + obj = NonExistentClass() + return obj.process(inputs) \ No newline at end of file diff --git a/tests/integration/examples/custom_handlers/missing_handle.py b/tests/integration/examples/custom_handlers/missing_handle.py new file mode 100644 index 000000000..4d82b8207 --- /dev/null +++ b/tests/integration/examples/custom_handlers/missing_handle.py @@ -0,0 +1,5 @@ +from djl_python import Input, Output + +# Missing handle function - should fall back to vLLM +def some_other_function(): + return "This is not a handle function" \ No newline at end of file diff --git a/tests/integration/examples/custom_handlers/runtime_error.py b/tests/integration/examples/custom_handlers/runtime_error.py new file mode 100644 index 000000000..bca931237 --- /dev/null +++ b/tests/integration/examples/custom_handlers/runtime_error.py @@ -0,0 +1,6 @@ +from djl_python import Input, Output +from djl_python.encode_decode import decode + +async def handle(inputs: Input): + """Custom handler that raises runtime error""" + raise RuntimeError("Custom handler intentional failure") \ No newline at end of file diff --git
from djl_python import Input, Output
from djl_python.async_utils import create_non_stream_output
from djl_python.encode_decode import decode
from vllm import LLM, SamplingParams
import json

# Lazily created vLLM engine, reused across calls to handle().
llm = None


async def handle(inputs: Input):
    """Generate text for the first request in the batch with a vLLM engine.

    The engine is created on the first call from the ``model_id`` property
    (default ``"gpt2"``). The response dict carries a ``custom_handler_used``
    marker next to the generated text. Progress is printed with a
    ``[CUSTOM_HANDLER]`` prefix.
    """
    global llm

    print("[CUSTOM_HANDLER] Starting handle function")

    # Build the engine only once.
    if llm is None:
        print("[CUSTOM_HANDLER] Initializing vLLM engine")
        model_id = inputs.get_properties().get("model_id", "gpt2")
        llm = LLM(model=model_id, tensor_parallel_size=1)
        print(f"[CUSTOM_HANDLER] vLLM engine initialized with model: {model_id}")

    # Decode the first request of the batch.
    first_request = inputs.get_batches()[0]
    payload = decode(first_request,
                     first_request.get_property("Content-Type"))

    # Fall back to a fixed prompt when none (or only whitespace) was sent.
    prompt = payload.get("inputs", "Hello")
    if not (prompt and prompt.strip() != ""):
        prompt = "Hello"

    print(f"[CUSTOM_HANDLER] Using prompt: {prompt}")

    sampling = SamplingParams(max_tokens=50, temperature=0.8)

    print("[CUSTOM_HANDLER] Starting generation")
    results = llm.generate([prompt], sampling)
    if results:
        generated_text = results[0].outputs[0].text
    else:
        generated_text = "No output"

    print(f"[CUSTOM_HANDLER] Generated text: {generated_text}")

    # The marker lets callers verify that this handler produced the response.
    response = {
        "custom_handler_used": True,
        "generated_text": generated_text
    }

    print(f"[CUSTOM_HANDLER] Response created: {response}")

    wrapped = create_non_stream_output(response)

    print("[CUSTOM_HANDLER] Output object created, returning")
    return wrapped
def test_custom_handler_async(model, model_spec):
    """Send one request per stream mode and check the custom-handler marker.

    Args:
        model: Model name, validated against ``model_spec``.
        model_spec: Mapping of model name to its test specification; the
            entry for ``args.model`` must provide ``seq_length``.

    Raises:
        AssertionError: If a response does not contain
            ``custom_handler_used``.
    """
    modelspec_checker(model, model_spec)
    spec = model_spec[args.model]
    if "worker" in spec:
        check_worker_number(spec["worker"])
    stream_modes = spec.get("stream", [False, True])

    # dryrun phase: a single prompt with sampling parameters, optionally
    # overridden by the spec
    prompt = batch_generation(1)[0]
    parameters = {
        "do_sample": True,
        "max_new_tokens": spec["seq_length"][0],
        "details": True,
    }
    if "parameters" in spec:
        parameters.update(spec["parameters"])

    payload = {"inputs": prompt, "parameters": parameters}
    if "adapters" in spec:
        payload["adapters"] = spec["adapters"][0]

    for stream_mode in stream_modes:
        payload["stream"] = stream_mode
        LOGGER.info(f"req {payload}")
        response = send_json(payload)
        body = response.content.decode("utf-8")
        LOGGER.info(f"res: {body}")
        response_checker(response, body)
        assert "custom_handler_used" in body, "Output does not contain custom_handler_used tag"
def build_vllm_async_model_with_custom_handler(model, handler_type="success"):
    """Write vLLM async-mode artifacts and install a custom ``model.py``.

    Args:
        model: Key into ``vllm_model_list``.
        handler_type: Base name of the handler file under
            ``examples/custom_handlers/`` that is copied to
            ``models/test/model.py``.

    Raises:
        ValueError: If ``model`` is not a key of ``vllm_model_list``.
        FileNotFoundError: If the requested handler file does not exist.
    """
    if model not in vllm_model_list:
        raise ValueError(
            f"{model} is not one of the supporting handler {list(vllm_model_list.keys())}"
        )

    # Fail fast when the handler file is missing: skipping the copy silently
    # would make the test exercise the default vLLM handler instead.
    source_file = f"examples/custom_handlers/{handler_type}.py"
    target_file = "models/test/model.py"
    if not os.path.exists(source_file):
        raise FileNotFoundError(
            f"Custom handler example not found: {source_file}")

    # Work on a copy so the shared vllm_model_list entry is not modified.
    options = dict(vllm_model_list[model])
    options["engine"] = "Python"
    options["option.rolling_batch"] = "disable"
    options["option.async_mode"] = "true"
    options["option.entryPoint"] = "djl_python.lmi_vllm.vllm_async_service"
    write_model_artifacts(options)

    # Copy custom handler from examples
    shutil.copy2(source_file, target_file)
@pytest.mark.vllm
@pytest.mark.gpu_4
class TestVllmCustomHandlers:
    """Integration tests for custom ``model.py`` handlers with vLLM async mode.

    Each test builds the model artifacts with one of the example handlers,
    then either expects the launch to fail or runs the client against the
    launched container.
    """

    # Model key passed to the artifact builder by every test in this class.
    _MODEL = "gpt-neox-20b-custom"

    def test_custom_handler_success(self):
        """Default (``success``) handler: launch and run the client."""
        with Runner('lmi', 'gpt-neox-20b-custom-handler') as runner:
            prepare.build_vllm_async_model_with_custom_handler(self._MODEL)
            runner.launch()
            client.run(["custom_handler", "gpt-neox-20b"])

    def test_custom_handler_syntax_error(self):
        """``syntax_error`` handler: the launch is expected to raise."""
        with Runner('lmi',
                    'gpt-neox-20b-custom-handler-syntax-error') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                self._MODEL, "syntax_error")
            with pytest.raises(Exception):
                runner.launch()

    def test_custom_handler_runtime_error(self):
        """``runtime_error`` handler: the client run must fail mentioning 424."""
        with Runner('lmi',
                    'gpt-neox-20b-custom-handler-runtime-error') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                self._MODEL, "runtime_error")
            runner.launch()
            with pytest.raises(ValueError, match=r".*424.*"):
                client.run(["custom_handler", "gpt-neox-20b"])

    def test_custom_handler_missing_handle(self):
        """``missing_handle`` handler: the plain vLLM client run must pass."""
        with Runner('lmi', 'gpt-neox-20b-custom-handler-missing') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                self._MODEL, "missing_handle")
            runner.launch()
            # Should fall back to vLLM
            client.run(["vllm", "gpt-neox-20b"])

    def test_custom_handler_import_error(self):
        """``import_error`` handler: the launch is expected to raise."""
        with Runner('lmi',
                    'gpt-neox-20b-custom-handler-import-error') as runner:
            prepare.build_vllm_async_model_with_custom_handler(
                self._MODEL, "import_error")
            with pytest.raises(Exception):
                runner.launch()