diff --git a/.github/workflows/docker_publish.yml b/.github/workflows/docker_publish.yml index 50daa0835..feebf5f5f 100644 --- a/.github/workflows/docker_publish.yml +++ b/.github/workflows/docker_publish.yml @@ -78,15 +78,25 @@ jobs: --fail \ | jq '.token' | tr -d '"' ) ./start_instance.sh action_cpu $token djl-serving + - name: Create new Graviton instance + id: create_aarch64 + run: | + cd /home/ubuntu/djl_benchmark_script/scripts + token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \ + https://api.github.com/repos/deepjavalibrary/djl-serving/actions/runners/registration-token \ + --fail \ + | jq '.token' | tr -d '"' ) + ./start_instance.sh action_graviton $token djl-serving outputs: cpu_instance_id1: ${{ steps.create_cpu_1.outputs.action_cpu_instance_id }} cpu_instance_id2: ${{ steps.create_cpu_2.outputs.action_cpu_instance_id }} cpu_instance_id3: ${{ steps.create_cpu_3.outputs.action_cpu_instance_id }} + aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }} docker-sync: runs-on: - self-hosted - - cpu + - ${{ matrix.arch != 'aarch64' && 'cpu' || 'aarch64' }} - RUN_ID-${{ github.run_id }} - RUN_NUMBER-${{ github.run_number }} - SHA-${{ github.sha }} @@ -154,3 +164,5 @@ jobs: ./stop_instance.sh $instance_id instance_id=${{ needs.create-runners.outputs.cpu_instance_id3 }} ./stop_instance.sh $instance_id + instance_id=${{ needs.create-runners.outputs.aarch64_instance_id }} + ./stop_instance.sh $instance_id diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index fd597c84f..ae0fb7711 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -161,9 +161,9 @@ jobs: - test: TestGpu_g6 instance: g6 failure-prefix: gpu - - test: TestAarch64 - instance: aarch64 - failure-prefix: aarch64 + # - test: TestAarch64 + # instance: aarch64 + # failure-prefix: aarch64 # - test: TestHfHandler_g6 # instance: g6 # failure-prefix: lmi diff --git 
a/engines/python/setup/djl_python/lmi_trtllm/request_response_utils.py b/engines/python/setup/djl_python/lmi_trtllm/request_response_utils.py
new file mode 100644
index 000000000..2f0a688b9
--- /dev/null
+++ b/engines/python/setup/djl_python/lmi_trtllm/request_response_utils.py
@@ -0,0 +1,221 @@
+import json
+from typing import Callable, Union, Tuple, List
+from tensorrt_llm.serve.openai_protocol import (
+    ErrorResponse,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionResponse,
+    CompletionRequest,
+    CompletionLogProbs,
+)
+from tensorrt_llm.llmapi.tokenizer import TokenizerBase
+from djl_python.async_utils import create_non_stream_output
+from djl_python.outputs import Output
+
+
+class ProcessedRequest:
+    """Bundle of a parsed TRT-LLM request and the callables used to serve it.
+
+    Attributes are plain data: the request object, the coroutine that runs
+    inference for it, the two output formatters, and the streaming flags
+    (``accumulate_chunks``, ``include_prompt``). ``lora_request`` is always
+    initialised to ``None`` here.
+    """
+
+    def __init__(
+        self,
+        trtllm_request: Union[CompletionRequest, ChatCompletionRequest],
+        inference_invoker: Callable,
+        non_stream_output_formatter: Callable,
+        stream_output_formatter: Callable,
+        accumulate_chunks: bool,
+        include_prompt: bool,
+    ):
+        self.trtllm_request = trtllm_request
+        self.inference_invoker = inference_invoker
+        # We need access to both the stream and non-stream output formatters here
+        # because even with streaming requests, there may be some errors before inference that
+        # result in a return of ErrorResponse object instead of AsyncGenerator
+        self.non_stream_output_formatter = non_stream_output_formatter
+        self.stream_output_formatter = stream_output_formatter
+        self.accumulate_chunks = accumulate_chunks
+        self.include_prompt = include_prompt
+        self.lora_request = None
+
+
+def convert_lmi_schema_to_completion_request(
+    payload: dict, ) -> Tuple[CompletionRequest, bool, bool]:
+    """Translate an LMI/TGI style payload into a TRT-LLM ``CompletionRequest``.
+
+    The payload (and its ``parameters`` dict) is consumed destructively via
+    ``pop``. Any parameter that is not explicitly renamed below is passed
+    through to ``CompletionRequest`` unchanged.
+
+    :param payload: decoded request body; must contain ``inputs`` and ``model``.
+    :return: ``(request, include_details_in_response, include_prompt)``.
+    :raises KeyError: if ``inputs`` or ``model`` is missing from the payload.
+    """
+    # `or {}` also covers an explicit `"parameters": null` in the request body
+    parameters = payload.get("parameters") or {}
+
+    completion_dict = {
+        "prompt": payload.pop("inputs"),
+        "model": payload.pop("model"),
+        "max_tokens": parameters.pop("max_new_tokens", 30),
+        "echo": parameters.pop("return_full_text", False),
+        "truncate_prompt_tokens": parameters.pop("truncate", None),
+        # NOTE(review): `top_n_tokens` is forwarded as `n` - confirm this mapping is intended
+        "n": parameters.pop("top_n_tokens", 1),
+        "ignore_eos": parameters.pop("ignore_eos_token", False),
+        "stream": payload.pop("stream", False),
+    }
+    # TRTLLM does not support logprobs in completions API. If provided, rely on TRTLLM validation error
+    include_details_in_response = False
+    include_prompt = False
+    if completion_dict["stream"]:
+        completion_dict["stream_options"] = {
+            "include_usage": True,
+            "continuous_usage_stats": True
+        }
+        # For streaming, `echo` is removed from the request and returned as
+        # `include_prompt` so the stream formatter can prepend the prompt itself.
+        include_prompt = completion_dict.pop("echo", False)
+    if parameters.pop("details", False):
+        include_details_in_response = True
+    if parameters.pop("decoder_input_details", False):
+        completion_dict["return_context_logits"] = 1
+    do_sample = parameters.pop("do_sample", None)
+    # when do_sample is None, just passthrough sampling params as sampling is dictated by the value of other params
+    # when do_sample is False, set sampling params such that we disable sampling
+    if do_sample is not None and not do_sample:
+        parameters["temperature"] = 0.0
+
+    completion_dict.update(parameters)
+
+    return CompletionRequest(
+        **completion_dict), include_details_in_response, include_prompt
+
+
+def convert_completion_response_to_lmi_schema(
+        response: CompletionResponse,
+        request: CompletionRequest = None,
+        include_details: bool = False,
+        tokenizer: TokenizerBase = None) -> Output:
+    """Convert a non-streaming ``CompletionResponse`` into the LMI response shape.
+
+    Only the first choice is reported. When ``include_details`` is true a
+    ``details`` object with finish reason, generated token count and seed is
+    attached. ``tokenizer`` is accepted for signature parity and is unused.
+    """
+    primary_choice = response.choices[0]
+    lmi_response = {"generated_text": primary_choice.text}
+    if not include_details:
+        return create_non_stream_output(lmi_response)
+    details = {
+        # Same precedence as the streaming formatter: finish_reason first,
+        # falling back to the stop string/token that ended generation.
+        "finish_reason": primary_choice.finish_reason
+        or primary_choice.stop_reason,
+        "generated_tokens": response.usage.completion_tokens,
+        # `request` defaults to None, so guard the attribute access
+        "seed": request.seed if request is not None else None,
+    }
+    lmi_response["details"] = details
+    output = create_non_stream_output(lmi_response)
+    return output
+
+
+def convert_completion_chunk_response_to_lmi_schema(
+    chunk: str,
+    include_details: bool = False,
+    history: List[str] = None,
+    request: CompletionRequest = None,
+    include_prompt: bool = False,
+    tokenizer: TokenizerBase = None,
+    **_,
+) -> Tuple[str, bool, List[str]]:
+    """Convert one TRT-LLM SSE completion chunk into a TGI style JSON chunk.
+
+    :param chunk: raw SSE line of the form ``data: {...}`` or ``data: [DONE]``.
+    :param include_details: attach a ``details`` object on the final chunk.
+    :param history: token texts seen so far; appended to in place. A new list
+        is created when ``None`` is passed.
+    :param request: originating request, used for ``seed`` and ``prompt``.
+    :param include_prompt: prepend ``request.prompt`` to the final generated text.
+    :return: ``(json_or_empty_string, is_last, history)``.
+    """
+    if history is None:
+        history = []
+    # TRTLLM returns chunks in string format, and the conversion process to TGI
+    # currently converts the string to an object, and then the object back to a string.
+    # It's much easier to work with the object instead of manipulating the string, but inefficient
+    trimmed_chunk = chunk[6:].strip()  # drop the 6-character 'data: ' prefix
+    if trimmed_chunk == '[DONE]':
+        data = ""
+        return data, True, history
+
+    trt_completion_chunk = json.loads(trimmed_chunk)
+    if "error" in trt_completion_chunk:
+        return json.dumps(trt_completion_chunk,
+                          ensure_ascii=False), True, history
+
+    if len(trt_completion_chunk["choices"]) == 0:
+        # penultimate chunk
+        return "", False, history
+    choice = trt_completion_chunk["choices"][0]
+    index = choice["index"]
+    token_text = choice["text"]
+    history.append(token_text)
+    finish_reason = choice["finish_reason"]
+    stop_reason = choice["stop_reason"]
+    usage = trt_completion_chunk["usage"]
+
+    # TODO: TokenId and LogProb here
+    token = {
+        "id": None,
+        "text": token_text,
+        "logprob": None,
+    }
+    tgi_chunk = {
+        "index": index,
+        "token": token,
+        "generated_text": None,
+        "details": None,
+    }
+    generation_finished = finish_reason is not None or stop_reason is not None
+    if generation_finished:
+        generated_text = ''.join(history)
+        if include_prompt and request is not None:
+            generated_text = request.prompt + generated_text
+        tgi_chunk["generated_text"] = generated_text
+        if include_details:
+            details = {
+                "finish_reason": finish_reason or stop_reason,
+                "seed": request.seed if request is not None else None,
+                # NOTE(review): the +1 presumably accounts for the final token
+                # not yet counted in this chunk's usage - confirm
+                "generated_tokens": usage["completion_tokens"] + 1,
+                "input_length": usage["prompt_tokens"],
+            }
+            tgi_chunk["details"] = details
+    json_str = json.dumps(tgi_chunk, ensure_ascii=False)
+    return json_str, False, history
+
+
+def lmi_with_details_non_stream_output_formatter(
+    response: CompletionResponse,
+    request: CompletionRequest = None,
+    tokenizer: TokenizerBase = None,
+) -> Output:
+    """Non-streaming LMI formatter that always includes the ``details`` object."""
+    return convert_completion_response_to_lmi_schema(response,
+                                                     include_details=True,
+                                                     request=request,
+                                                     tokenizer=tokenizer)
+
+
+def lmi_non_stream_output_formatter(
+    response: CompletionResponse,
+    request: CompletionRequest = None,
+    tokenizer: TokenizerBase = None,
+) -> Output:
+    """Non-streaming LMI formatter that returns only ``generated_text``."""
+    return convert_completion_response_to_lmi_schema(response,
+                                                     include_details=False,
+                                                     request=request,
+                                                     tokenizer=tokenizer)
+
+
+def lmi_with_details_stream_output_formatter(
+    chunk: str,
+    **kwargs,
+) -> Tuple[str, bool, List[str]]:
+    """Streaming LMI formatter that attaches ``details`` to the final chunk."""
+    return convert_completion_chunk_response_to_lmi_schema(
+        chunk, include_details=True, **kwargs)
+
+
+def lmi_stream_output_formatter(
+    chunk: str,
+    **kwargs,
+) -> Tuple[str, bool, List[str]]:
+    """Streaming LMI formatter without ``details``."""
+    return convert_completion_chunk_response_to_lmi_schema(chunk, **kwargs)
+
+
+def trtllm_non_stream_output_formatter(
+    response: Union[ErrorResponse, ChatCompletionResponse, CompletionResponse],
+    **_,
+) -> Output:
+    """Serialise an OpenAI-schema response (or error) into an ``Output``.
+
+    ``ErrorResponse`` objects are mapped onto the output's error message and
+    code; any other response is dumped to JSON as-is.
+    """
+    if isinstance(response, ErrorResponse):
+        return create_non_stream_output("",
+                                        error=response.message,
+                                        code=response.code)
+    response_data = response.model_dump_json()
+    return create_non_stream_output(response_data)
+
+
+def trtllm_stream_output_formatter(
+    chunk: str,
+    **_,
+) -> Tuple[str, bool]:
+    """Strip the SSE framing from a chunk and flag the ``[DONE]`` sentinel.
+
+    :return: ``(payload, is_last)``; the payload is empty for ``[DONE]``.
+    """
+    # trtllm returns responses in sse format, 'data: {...}'
+    trimmed_chunk = chunk[6:].strip()
+    if trimmed_chunk == '[DONE]':
+        data = ""
+        last = True
+    else:
+        data = trimmed_chunk
+        last = False
+    return data, last
diff --git a/engines/python/setup/djl_python/lmi_trtllm/trtllm_async_service.py b/engines/python/setup/djl_python/lmi_trtllm/trtllm_async_service.py
new file mode 100644
index 000000000..2fd06584d
--- /dev/null
+++ b/engines/python/setup/djl_python/lmi_trtllm/trtllm_async_service.py
@@ -0,0 +1,425 @@
+#!/usr/bin/env python
+#
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License.
A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+# Heavily inspired by https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serve/openai_server.py
+import asyncio
+import logging
+import signal
+import types
+from http import HTTPStatus
+from typing import AsyncGenerator, AsyncIterator, List, Tuple, TypedDict, Union, Optional
+
+from openai.types.chat import ChatCompletionMessageParam
+
+# yapf: disable
+from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm.executor import CppExecutorError
+from tensorrt_llm.executor.postproc_worker import PostprocParams
+from tensorrt_llm.llmapi.llm import RequestOutput
+from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
+                                                ChatCompletionResponse,
+                                                CompletionRequest,
+                                                CompletionResponse,
+                                                CompletionResponseChoice,
+                                                ErrorResponse, ModelCard,
+                                                ModelList, UsageInfo,
+                                                to_llm_disaggregated_params)
+from tensorrt_llm.serve.postprocess_handlers import (
+    ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
+    chat_stream_post_processor, completion_response_post_processor,
+    completion_stream_post_processor)
+
+from djl_python.async_utils import create_non_stream_output, handle_streaming_response
+from djl_python.inputs import Input
+from djl_python.outputs import Output
+from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
+from djl_python.encode_decode import decode
+
+from .request_response_utils import (
+    trtllm_non_stream_output_formatter,
+    trtllm_stream_output_formatter,
+    convert_lmi_schema_to_completion_request,
+    lmi_with_details_stream_output_formatter,
+    lmi_stream_output_formatter,
+    lmi_with_details_non_stream_output_formatter,
+    lmi_non_stream_output_formatter,
+    ProcessedRequest)
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationMessage(TypedDict):
+    """One chat turn reduced to a role and plain-text content."""
+    role: str
+    content: str
+
+
+def parse_chat_message_content(
+    message: ChatCompletionMessageParam, ) -> List[ConversationMessage]:
+    """Flatten one OpenAI chat message into zero or one text-only messages.
+
+    :return: ``[]`` when the message has no content, otherwise a single
+        ``ConversationMessage``; multi-part text content is joined with newlines.
+    :raises NotImplementedError: for any content part whose type is not ``text``.
+    """
+    role = message["role"]
+    content = message.get("content")
+
+    if content is None:
+        return []
+    if isinstance(content, str):
+        return [ConversationMessage(role=role, content=content)]
+
+    # for Iterable[ChatCompletionContentPartTextParam]
+    texts: List[str] = []
+    for part in content:
+        part_type = part["type"]
+        if part_type == "text":
+            text = part["text"]
+            texts.append(text)
+        else:
+            raise NotImplementedError(f"{part_type} is not supported")
+
+    text_prompt = "\n".join(texts)
+    return [ConversationMessage(role=role, content=text_prompt)]
+
+
+class TensorRTLlmAsyncService:
+    """Async handler that serves OpenAI and LMI schema requests through TRT-LLM."""
+
+    def __init__(self):
+        self.trt_configs = None
+        self.llm = None
+        self.tokenizer = None
+        self.model_name = None
+        self.postproc_worker_enabled = False
+        self.initialized = False
+
+    def initialize(self, properties: dict):
+        """Parse the serving properties and construct the TRT-LLM ``LLM``.
+
+        :param properties: raw properties passed in by the front-end.
+        :raises Exception: re-raises whatever ``LLM(...)`` raises, after logging
+            the model path and the kwarg names.
+        """
+        import os
+        logger.info(f"Raw properties received: {properties}")
+        self.trt_configs = TensorRtLlmProperties(**properties)
+        logger.info(f"Extra properties in trt_configs: {self.trt_configs.__pydantic_extra__}")
+        llm_kwargs = self.trt_configs.get_llm_kwargs()
+        logger.info(f"LLM kwargs being passed: {llm_kwargs}")
+        logger.info(f"Backend setting: {self.trt_configs.backend}")
+
+        # Check what files exist in model directory
+        model_path = self.trt_configs.model_id_or_path
+        # isdir rather than exists: os.listdir raises on a plain file path
+        if os.path.isdir(model_path):
+            files = os.listdir(model_path)
+            logger.info(f"Files in model directory: {files}")
+            # Check for specific config files
+            for config_file in ['config.json', 'model_config.json', 'engine_config.json']:
+                config_path = os.path.join(model_path, config_file)
+                if os.path.exists(config_path):
+                    logger.info(f"Found {config_file} in model directory")
+        # In this handler, we expect the front-end to pre-compile the model into a TRTLLM engine.
+        # That means we do not expose build config customization here, and only set runtime configs
+        try:
+            # get_llm_kwargs() already carries tensor_parallel_size,
+            # trust_remote_code, dtype and revision; naming them again here
+            # would raise "got multiple values for keyword argument".
+            self.llm = LLM(
+                model=self.trt_configs.model_id_or_path,
+                **llm_kwargs,
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize LLM with error: {e}")
+            logger.error(f"Model path: {self.trt_configs.model_id_or_path}")
+            logger.error(f"LLM kwargs keys: {list(llm_kwargs.keys())}")
+            raise
+        self.tokenizer = self.llm.tokenizer
+        self.model_name = self.trt_configs.model_id_or_path
+        self.postproc_worker_enabled = self.llm.args.num_postprocess_workers > 0
+        self.initialized = True
+
+    def preprocess_requests(self, inputs: Input) -> ProcessedRequest:
+        """Decode the single request in ``inputs`` and pick its invoker/formatters.
+
+        The schema is selected by the first key found in the payload:
+        ``prompt`` (OpenAI completions), ``messages`` (OpenAI chat) or
+        ``inputs`` (LMI schema).
+
+        :raises ValueError: if the batch does not contain exactly one request.
+        :raises RuntimeError: if none of the three keys is present.
+        """
+        batch = inputs.get_batches()
+        if len(batch) != 1:
+            # explicit raise: an assert would be stripped under `python -O`
+            raise ValueError("only one request per batch allowed")
+        raw_request = batch[0]
+        content_type = raw_request.get_property("Content-Type")
+        decoded_payload = decode(raw_request, content_type)
+        accumulate_chunks = False
+        include_prompt = False
+        if "model" not in decoded_payload:
+            decoded_payload["model"] = self.model_name
+        if "prompt" in decoded_payload:
+            request = CompletionRequest(**decoded_payload)
+            invoke_call = self.openai_completion
+            non_stream_output_formatter = trtllm_non_stream_output_formatter
+            stream_output_formatter = trtllm_stream_output_formatter
+        elif "messages" in decoded_payload:
+            request = ChatCompletionRequest(**decoded_payload)
+            invoke_call = self.openai_chat
+            non_stream_output_formatter = trtllm_non_stream_output_formatter
+            stream_output_formatter = trtllm_stream_output_formatter
+        elif "inputs" in decoded_payload:
+            request, include_details, include_prompt = convert_lmi_schema_to_completion_request(decoded_payload)
+            invoke_call = self.openai_completion
+            non_stream_output_formatter = lmi_with_details_non_stream_output_formatter if include_details else lmi_non_stream_output_formatter
+            stream_output_formatter = lmi_with_details_stream_output_formatter if include_details else lmi_stream_output_formatter
+            accumulate_chunks = True
+        else:
+            raise RuntimeError("invalid payload. must contain prompt, inputs, or messages")
+        processed_request = ProcessedRequest(
+            request,
+            invoke_call,
+            non_stream_output_formatter,
+            stream_output_formatter,
+            accumulate_chunks,
+            include_prompt,
+        )
+        return processed_request
+
+    async def inference(self, inputs: Input) -> Union[Output, AsyncGenerator[Output, None]]:
+        """Run one request end to end.
+
+        Parsing failures are reported as an ``Output`` with code 424 rather
+        than raised. An async-generator result is wrapped by
+        ``handle_streaming_response``; anything else goes through the
+        non-stream formatter.
+        """
+        try:
+            processed_request = self.preprocess_requests(inputs)
+        except Exception as e:
+            logger.exception("Input parsing failed")
+            output = create_non_stream_output("", error=f"Input parsing failed: {e}", code=424)
+            return output
+
+        response = await processed_request.inference_invoker(processed_request.trtllm_request)
+        if isinstance(response, types.AsyncGeneratorType):
+            return handle_streaming_response(
+                response,
+                processed_request.stream_output_formatter,
+                accumulate_chunks=processed_request.accumulate_chunks,
+                include_prompt=processed_request.include_prompt,
+                tokenizer=self.tokenizer,
+                request=processed_request.trtllm_request,
+            )
+
+        return processed_request.non_stream_output_formatter(
+            response,
+            request=processed_request.trtllm_request,
+            tokenizer=self.tokenizer,
+        )
+
+    @staticmethod
+    def create_error_response(
+            message: str,
+            err_type: str = "BadRequestError",
+            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+        """Build an ``ErrorResponse`` carrying the numeric HTTP status code."""
+        error_response = ErrorResponse(message=message,
+                                       type=err_type,
+                                       code=status_code.value)
+        return error_response
+
+    async def openai_chat(self, request: ChatCompletionRequest) -> Union[ErrorResponse, ChatCompletionResponse, AsyncGenerator[str, None]]:
+        """Serve an OpenAI chat-completions request.
+
+        :return: an async generator of SSE strings when ``request.stream`` is
+            set, a ``ChatCompletionResponse`` otherwise, or an
+            ``ErrorResponse`` if anything other than ``CppExecutorError`` fails.
+        """
+
+        def get_role() -> str:
+            if request.add_generation_prompt:
+                role = "assistant"
+            else:
+                role = request.messages[-1]["role"]
+            return role
+
+        async def chat_stream_generator(
+                promise: RequestOutput, postproc_params: PostprocParams) -> AsyncGenerator[str, None]:
+            if not self.postproc_worker_enabled:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+            async for res in promise:
+                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                for pp_res in pp_results:
+                    yield pp_res
+            yield "data: [DONE]\n\n"
+
+        async def create_chat_response(
+                promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
+            await promise.aresult()
+            if self.postproc_worker_enabled:
+                return promise.outputs[0]._postprocess_result
+            else:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                return post_processor(promise, args)
+
+        try:
+            conversation: List[ConversationMessage] = []
+            for msg in request.messages:
+                conversation.extend(parse_chat_message_content(msg))
+            tool_dicts = None if request.tools is None else [
+                tool.model_dump() for tool in request.tools
+            ]
+            prompt: str = self.tokenizer.apply_chat_template(
+                conversation=conversation,
+                tokenize=False,
+                add_generation_prompt=request.add_generation_prompt,
+                tools=tool_dicts,
+                documents=request.documents,
+                chat_template=request.chat_template,
+                **(request.chat_template_kwargs or {}),
+            )
+            sampling_params = request.to_sampling_params()
+            disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
+            postproc_args = ChatPostprocArgs.from_request(request)
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == get_role():
+                postproc_args.last_message_content = conversation[-1]["content"]
+            postproc_params = PostprocParams(
+                post_processor=chat_stream_post_processor
+                if request.stream else chat_response_post_processor,
+                postproc_args=postproc_args,
+            )
+
+            promise = self.llm.generate_async(
+                inputs=prompt,
+                sampling_params=sampling_params,
+                _postproc_params=postproc_params if self.postproc_worker_enabled else None,
+                streaming=request.stream,
+                disaggregated_params=disaggregated_params
+            )
+            if not self.postproc_worker_enabled:
+                postproc_args.tokenizer = self.tokenizer
+                postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
+
+            if request.stream:
+                response_generator = chat_stream_generator(promise, postproc_params)
+                return response_generator
+            else:
+                response = await create_chat_response(promise, postproc_params)
+                return response
+        except CppExecutorError:
+            # If internal executor error is raised, shutdown the server
+            signal.raise_signal(signal.SIGINT)
+        except Exception as e:
+            # log like openai_completion does, so chat failures are not silent
+            logger.exception(f"Encountered an exception: {str(e)}")
+            return self.create_error_response(str(e))
+
+    async def openai_completion(self, request: CompletionRequest) -> Union[ErrorResponse, CompletionResponse, AsyncGenerator[str, None]]:
+        """Serve an OpenAI completions request (also used for the LMI schema).
+
+        One generation is started per prompt and the results are merged.
+
+        :return: an async generator of SSE strings when ``request.stream`` is
+            set, a ``CompletionResponse`` otherwise, or an ``ErrorResponse``
+            if anything other than ``CppExecutorError`` fails.
+        """
+
+        def merge_promises(
+            promises: List[RequestOutput],
+            postproc_params_collections: List[Optional[PostprocParams]]
+        ) -> AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]]:
+            outputs = asyncio.Queue()
+            finished = [False] * len(promises)
+
+            async def producer(i: int, promise: RequestOutput, postproc_params: Optional[PostprocParams]):
+                async for output in promise:
+                    await outputs.put((output, postproc_params))
+                finished[i] = True
+
+            _tasks = [
+                asyncio.create_task(producer(i, promise, postproc_params))
+                for i, (promise, postproc_params) in enumerate(zip(promises, postproc_params_collections))
+            ]
+
+            async def consumer():
+                while not all(finished) or not outputs.empty():
+                    item = await outputs.get()
+                    yield item
+                await asyncio.gather(*_tasks)
+
+            return consumer()
+
+        async def create_completion_generator(
+                generator: AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]]):
+            async for request_output, postproc_params in generator:
+                if not self.postproc_worker_enabled:
+                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                    pp_result = post_processor(request_output, args)
+                else:
+                    pp_result = request_output.outputs[0]._postprocess_result
+                for pp_res in pp_result:
+                    yield pp_res
+            yield "data: [DONE]\n\n"
+
+        async def create_completion_response(
+                generator: AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]]) -> CompletionResponse:
+            all_choices: List[CompletionResponseChoice] = []
+            num_prompt_tokens = num_gen_tokens = 0
+            async for request_output, postproc_params in generator:
+                pp_result: CompletionResponse
+                if not self.postproc_worker_enabled:
+                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                    pp_result = post_processor(request_output, args)
+                else:
+                    pp_result = request_output.outputs[0]._postprocess_result
+
+                choices, usage = pp_result.choices, pp_result.usage
+                all_choices.extend(choices)
+                num_prompt_tokens += usage.prompt_tokens
+                num_gen_tokens += usage.completion_tokens
+
+            usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=num_gen_tokens,
+                total_tokens=num_gen_tokens + num_prompt_tokens,
+            )
+            response = CompletionResponse(
+                model=self.model_name,
+                choices=all_choices,
+                usage=usage_info,
+            )
+            return response
+
+        try:
+            # A bare string or a flat list of token ids is a single prompt
+            if isinstance(request.prompt, str) or \
+                (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
+                prompts = [request.prompt]
+            else:
+                prompts = request.prompt
+
+            promises: List[RequestOutput] = []
+            postproc_params_collection: List[Optional[PostprocParams]] = []
+            sampling_params = request.to_sampling_params()
+            disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
+            for idx, prompt in enumerate(prompts):
+                postproc_args = CompletionPostprocArgs.from_request(request)
+                postproc_args.prompt_idx = idx
+                if request.echo:
+                    postproc_args.prompt = prompt
+                postproc_params = PostprocParams(
+                    post_processor=completion_stream_post_processor
+                    if request.stream else completion_response_post_processor,
+                    postproc_args=postproc_args,
+                )
+                promise = self.llm.generate_async(
+                    inputs=prompt,
+                    sampling_params=sampling_params,
+                    _postproc_params=postproc_params,
+                    streaming=request.stream,
+                    disaggregated_params=disaggregated_params
+                )
+                if not self.postproc_worker_enabled:
+                    postproc_args.tokenizer = self.tokenizer
+                    postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
+                promises.append(promise)
+                postproc_params_collection.append(None if self.postproc_worker_enabled else postproc_params)
+
+            generator = merge_promises(promises, postproc_params_collection)
+            if request.stream:
+                response_generator = create_completion_generator(
+                    generator)
+                return response_generator
+            else:
+                response = await create_completion_response(
+                    generator)
+                return response
+        except CppExecutorError:
+            # If internal executor error is raised, shutdown the server
+            signal.raise_signal(signal.SIGINT)
+        except Exception as e:
+            logger.exception(f"Encountered an exception: {str(e)}")
+            return self.create_error_response(str(e))
+
+
+service = TensorRTLlmAsyncService()
+
+
+async def handle(inputs: Input) -> Optional[Output]:
+    """Module entry point: lazily initialise the service, then run inference.
+
+    :return: ``None`` for an empty input, otherwise the result of
+        ``service.inference``.
+    """
+    if not service.initialized:
+        logger.info("########## Initializing TRTLLM Service ##########")
+        logger.info(inputs.get_properties())
+        service.initialize(inputs.get_properties())
+        logger.info("trtllm service initialized")
+    if inputs.is_empty():
+        return None
+
+    outputs = await service.inference(inputs)
+    return outputs
\ No newline at end of file
diff --git a/engines/python/setup/djl_python/properties_manager/trt_properties.py b/engines/python/setup/djl_python/properties_manager/trt_properties.py
index 4450ad5c0..3a9feec7a 100644
--- a/engines/python/setup/djl_python/properties_manager/trt_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/trt_properties.py
@@ -10,8 +10,15 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied.
See the License for
 # the specific language governing permissions and limitations under the License.
+import json
+from pydantic import ConfigDict, field_validator
+from typing import Optional
+
+from tensorrt_llm.llmapi import KvCacheConfig
+from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
+from tensorrt_llm.llmapi import QuantConfig, CalibConfig, QuantAlgo
+
 from djl_python.properties_manager.properties import Properties, RollingBatchEnum
-from pydantic import field_validator
 
 TRT_SUPPORTED_ROLLING_BATCH_TYPES = [
     RollingBatchEnum.auto.value, RollingBatchEnum.trtllm.value,
@@ -21,6 +28,143 @@ class TensorRtLlmProperties(Properties):
+    # in our implementation, we handle compilation/quantization ahead of time.
+    # we do not expose build_config, quant_config, calib_config here as those get handled by
+    # the compilation in trt_llm_partition.py. We do it this way so that users can completely build the complete
+    # trt engine ahead of time via that script. If provided just a HF model id, then that script gets invoked,
+    # does compilation/quantization and generates engines that will get loaded here. We are only exposing
+    # runtime knobs here.
+
+    tokenizer: Optional[str] = None
+    tokenizer_mode: str = 'auto'
+    skip_tokenizer_init: bool = False
+    dtype: str = 'auto'
+    revision: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
+    pipeline_parallel_size: int = 1
+    context_parallel_size: int = 1
+    moe_tensor_parallel_size: Optional[int] = None
+    moe_expert_parallel_size: Optional[int] = None
+    enable_attention_dp: bool = False
+    auto_parallel: bool = False
+    auto_parallel_world_size: Optional[int] = None
+    load_format: str = 'auto'
+    enable_chunked_prefill: bool = False
+    guided_decoding_backend: Optional[str] = None
+    iter_stats_max_iterations: Optional[int] = None
+    request_stats_max_iterations: Optional[int] = None
+    embedding_parallel_mode: str = 'SHARDING_ALONG_VOCAB'
+    fast_build: bool = False
+    # different default! allows for faster loading on worker restart
+    enable_build_cache: bool = True
+    # NOTE(review): Optional[None] only admits None - confirm the intended type
+    batching_type: Optional[None] = None
+    normalize_log_probs: bool = False
+    gather_generation_logits: bool = False
+    # NOTE(review): Optional[None] only admits None - confirm the intended type
+    extended_runtime_perf_knob_config: Optional[None] = None
+    max_batch_size: Optional[int] = None
+    max_input_len: int = 1024
+    max_seq_len: Optional[int] = None
+    max_beam_width: int = 1
+    max_num_tokens: Optional[int] = None
+    backend: Optional[str] = None
+
+    # extra='allow' keeps undeclared properties in __pydantic_extra__, which
+    # the kv-cache and pytorch config builders below read from
+    model_config = ConfigDict(extra='allow', populate_by_name=True)
+
+    def get_kv_cache_config(self) -> Optional[KvCacheConfig]:
+        """Build a ``KvCacheConfig`` from the undeclared (extra) properties.
+
+        Only keys that are actually present are converted and forwarded.
+
+        :return: the config, or ``None`` when no kv-cache key was supplied.
+        """
+
+        def to_bool(value) -> bool:
+            # str() so an already-parsed bool does not fail on .lower()
+            return str(value).lower() == "true"
+
+        # property name -> converter applied to its raw value
+        converters = {
+            "enable_block_reuse": to_bool,
+            "max_tokens": int,
+            "max_attention_window": json.loads,
+            "sink_token_length": int,
+            "free_gpu_memory_fraction": float,
+            "host_cache_size": int,
+            "onboard_blocks": to_bool,
+            "cross_kv_cache_fraction": float,
+            "secondary_offload_min_priority": int,
+            "event_buffer_max_size": int,
+        }
+        extras = self.__pydantic_extra__
+        kv_cache_config = {
+            key: convert(extras[key])
+            for key, convert in converters.items() if key in extras
+        }
+        if kv_cache_config:
+            return KvCacheConfig(**kv_cache_config)
+        return None
+
+    def get_pytorch_config(self) -> Optional[PyTorchConfig]:
+        """Build a ``PyTorchConfig`` when the backend is ``pytorch``.
+
+        :return: the config, or ``None`` for any other backend.
+        """
+        if self.backend != 'pytorch':
+            return None
+        extras = self.__pydantic_extra__
+        # https://github.com/NVIDIA/TensorRT-LLM/blob/v0.20.0rc0/examples/pytorch/quickstart_advanced.py#L107
+        pytorch_config = {
+            "enable_overlap_scheduler":
+            str(extras.get('enable_overlap_scheduler',
+                           'false')).lower() == 'true',
+            "kv_cache_dtype":
+            extras.get('kv_cache_dtype', 'auto'),
+            "attn_backend":
+            extras.get('attn_backend', 'TRTLLM'),
+            'use_cuda_graph':
+            str(extras.get('use_cuda_graph', 'false')).lower() == 'true',
+            # load_format is a declared field, so it never appears in the
+            # extras; read the validated attribute instead
+            'load_format':
+            self.load_format,
+            'moe_backend':
+            extras.get('moe_backend', 'CUTLASS')
+        }
+        return PyTorchConfig(**pytorch_config)
+
+    def get_llm_kwargs(self) -> dict:
+        """Return the keyword arguments forwarded to the TRT-LLM ``LLM``.
+
+        ``tensor_parallel_size`` and ``max_batch_size`` are taken from
+        ``tensor_parallel_degree`` and ``max_rolling_batch_size`` respectively.
+        """
+        return {
+            "tokenizer": self.tokenizer,
+            "tokenizer_mode": self.tokenizer_mode,
+            "skip_tokenizer_init": self.skip_tokenizer_init,
+            "trust_remote_code": self.trust_remote_code,
+            "dtype": self.dtype,
+            "revision": self.revision,
+            "tokenizer_revision": self.tokenizer_revision,
+            "tensor_parallel_size": self.tensor_parallel_degree,
+            "pipeline_parallel_size": self.pipeline_parallel_size,
+            "context_parallel_size": self.context_parallel_size,
+            "moe_tensor_parallel_size": self.moe_tensor_parallel_size,
+            "moe_expert_parallel_size": self.moe_expert_parallel_size,
+            "enable_attention_dp": self.enable_attention_dp,
+            "auto_parallel": self.auto_parallel,
+            "auto_parallel_world_size": self.auto_parallel_world_size,
+            "load_format": self.load_format,
+            "enable_chunked_prefill": self.enable_chunked_prefill,
+            "guided_decoding_backend": self.guided_decoding_backend,
+            "iter_stats_max_iterations": self.iter_stats_max_iterations,
+            "request_stats_max_iterations": self.request_stats_max_iterations,
+            "embedding_parallel_mode": self.embedding_parallel_mode,
+            "enable_build_cache": self.enable_build_cache,
+            "batching_type": self.batching_type,
+            "normalize_log_probs": self.normalize_log_probs,
+            "gather_generation_logits": self.gather_generation_logits,
+            "max_batch_size": self.max_rolling_batch_size,
+            "max_input_len": self.max_input_len,
+            "max_seq_len": self.max_seq_len,
+            "max_beam_width": self.max_beam_width,
+            "max_num_tokens": self.max_num_tokens,
+            "backend": self.backend,
+            "kv_cache_config": self.get_kv_cache_config(),
+            "pytorch_config": self.get_pytorch_config(),
+        }
+
     @field_validator('rolling_batch', mode='before')
     def validate_rolling_batch(cls, rolling_batch: str) -> str:
         rolling_batch = rolling_batch.lower()
diff --git a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
index 7a20d17a1..c742573a5 100644
--- a/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
+++ b/engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
@@ -122,7 +122,7 @@ def validate_pipeline_parallel(self):
     def validate_tool_call_parser(self):
         if self.enable_auto_tool_choice:
             from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-            valid_tool_parses = ToolParserManager.tool_parsers.keys()
+            valid_tool_parses = ToolParserManager.list_registered()
             if self.tool_call_parser not in valid_tool_parses:
                 raise ValueError(
                     f"Invalid tool call parser: {self.tool_call_parser} "
@@ -133,8 +133,7 @@ def validate_tool_call_parser(self):
     def validate_reasoning_parser(self):
         if self.enable_reasoning:
             from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
-            valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys(
-            )
+            valid_reasoning_parses = ReasoningParserManager.list_registered()
             if
self.reasoning_parser not in valid_reasoning_parses: raise ValueError( f"Invalid reasoning parser: {self.reasoning_parser} " diff --git a/engines/python/setup/djl_python/tests/test_properties_manager.py b/engines/python/setup/djl_python/tests/test_properties_manager.py index a008a3309..0457c0255 100644 --- a/engines/python/setup/djl_python/tests/test_properties_manager.py +++ b/engines/python/setup/djl_python/tests/test_properties_manager.py @@ -6,8 +6,6 @@ from vllm import EngineArgs from djl_python.properties_manager.properties import Properties - -from djl_python.properties_manager.trt_properties import TensorRtLlmProperties from djl_python.properties_manager.hf_properties import HuggingFaceProperties from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties @@ -94,16 +92,6 @@ def test_common_configs_error_case(self): with self.assertRaises(ValueError): Properties(**other_properties) - @parameters([{ - "rolling_batch": "auto", - }]) - def test_trt_llm_configs(self, params): - properties = {**model_min_properties, **params} - trt_configs = TensorRtLlmProperties(**properties) - self.assertEqual(trt_configs.model_id_or_path, properties['model_id']) - self.assertEqual(trt_configs.rolling_batch.value, - properties['rolling_batch']) - def test_hf_configs(self): properties = { "model_id": "model_id", diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java index 97807c679..857eae5cf 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java @@ -79,7 +79,8 @@ public void load(Path modelPath, String prefix, Map options) throws I String entryPoint = null; String recommendedEntryPoint = null; if (options != null) { - // If tp_degree set to "max", we defer and set it at the end to ensure we take pp degree + // If tp_degree set to "max", we defer and set it at the end to ensure we take + 
// pp degree // into account. boolean setTensorParallelDegreeToMax = false; logger.debug("options in serving.properties for model: {}", modelName); @@ -183,7 +184,7 @@ public void load(Path modelPath, String prefix, Map options) throws I modelDir, prefix, ".skops", ".joblib", ".pkl", ".pickle", ".cloudpkl")) { recommendedEntryPoint = "djl_python.sklearn_handler"; } else if ("trtllm".equals(features)) { - recommendedEntryPoint = "djl_python.tensorrt_llm"; + recommendedEntryPoint = "djl_python.lmi_trtllm.trtllm_async_service"; } else if ("vllm".equals(features)) { if (pyEnv.isAsyncMode()) { recommendedEntryPoint = "djl_python.lmi_vllm.vllm_async_service"; diff --git a/serving/docker/lmi-container-requirements.txt b/serving/docker/lmi-container-requirements.txt index 8f3affe2c..09b41ec30 100644 --- a/serving/docker/lmi-container-requirements.txt +++ b/serving/docker/lmi-container-requirements.txt @@ -32,7 +32,7 @@ uvloop ninja peft llmcompressor -https://vllm-wheels.s3.us-west-2.amazonaws.com/d3ab240f39219df0175ec662416f630d7bf273d8/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl +https://vllm-wheels.s3.us-west-2.amazonaws.com/93103575ce0480f36fc1a3603eb51d9a89f38a00/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl xgrammar -flashinfer-python==0.4.1 +flashinfer-python==0.5.2 lmcache \ No newline at end of file diff --git a/serving/docker/partition/trt_llm_partition.py b/serving/docker/partition/trt_llm_partition.py index 151c36e5b..bf0e4d2ea 100644 --- a/serving/docker/partition/trt_llm_partition.py +++ b/serving/docker/partition/trt_llm_partition.py @@ -1,39 +1,290 @@ -#!/usr/bin/env python -# # Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file -# except in compliance with the License. 
A copy of the License is located at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://aws.amazon.com/apache2.0/ # -# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" -# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for -# the specific language governing permissions and limitations under the License. - +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse -import logging import os -import sys -from tensorrt_llm_toolkit import create_model_repo +import time +import json +from typing import Optional + +from tensorrt_llm.auto_parallel import infer_cluster_config +from tensorrt_llm.commands.build import parse_arguments, parallel_build +from tensorrt_llm.logger import logger, severity_map +from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode +from tensorrt_llm.plugin import PluginConfig, add_plugin_argument +from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm.llmapi import BuildConfig, QuantConfig, CalibConfig, QuantAlgo from utils import update_kwargs_with_env_vars, load_properties, remove_option_from_properties -def create_trt_llm_repo(properties, args): - kwargs = remove_option_from_properties(properties) - kwargs['trt_llm_model_repo'] = args.trt_llm_model_repo - kwargs["tensor_parallel_degree"] = args.tensor_parallel_degree - kwargs["pipeline_parallel_degree"] = args.pipeline_parallel_degree - model_id_or_path = args.model_path or kwargs['model_id'] - create_model_repo(model_id_or_path, **kwargs) +def build_engine( + 
trtllm_engine_configs: dict, + model_id: str, + output_dir: str, + tensor_parallel_degree: int, +): + tik = time.time() + llm_model = LLM(model=model_id, + tensor_parallel_size=tensor_parallel_degree, + trust_remote_code=trtllm_engine_configs.get( + "trust_remote_code", False), + dtype=trtllm_engine_configs.get("dtype", "auto"), + revision=trtllm_engine_configs.get("revision", None), + **trtllm_engine_configs.llm_kwargs) + + logger.info(f"[LMI] Model Compiled successfully, saving to {output_dir}") + llm_model.save(output_dir) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logger.info(f'Total time of building all engines: {t}') + + +def parse_build_config(properties: dict) -> BuildConfig: + if "max_rolling_batch_size" in properties: + properties["max_batch_size"] = properties["max_rolling_batch_size"] + trtllm_args = [] + for k, v in properties.items(): + trtllm_args.append(f"--{k}") + trtllm_args.append(f"{v}") + parser = parse_arguments() + args, unknown = parser.parse_known_args(args=trtllm_args) + logger.info( + f"[LMI] The following args will be passed to the build_config for TRTLLM: {args}" + ) + logger.info( + f"[LMI] The following args are not used by TRTLLM build and will be saved for the runtime configuration: {unknown}" + ) + + if hasattr(args, 'gather_generation_logits'): + logger.warning( + 'Option --gather_generation_logits is deprecated, a build flag is not required anymore. Use --output_generation_logits at runtime instead.' + ) + + if args.gather_all_token_logits: + args.gather_context_logits = True + args.gather_generation_logits = True + if args.gather_context_logits and args.max_draft_len > 0: + raise RuntimeError( + "Gather context logits is not support with draft len > 0. 
" + "If want to get the accepted tokens' logits from target model, please just enable gather_generation_logits" + ) + + if hasattr(args, 'paged_kv_cache'): + logger.warning( + 'Option --paged_kv_cache is deprecated, use --kv_cache_type=paged/disabled instead.' + ) + + plugin_config = PluginConfig.from_arguments(args) + plugin_config.validate() + if args.fast_build: + plugin_config.manage_weights = True + + speculative_decoding_mode = SpeculativeDecodingMode.from_arguments(args) + + if args.build_config is None: + if args.multiple_profiles == "enable" and args.opt_num_tokens is not None: + raise RuntimeError( + "multiple_profiles is enabled, while opt_num_tokens is set. " + "They are not supposed to be working in the same time for now." + ) + if args.cluster_key is not None: + cluster_config = dict(cluster_key=args.cluster_key) + else: + cluster_config = infer_cluster_config() + + # This should only be used for debugging. + # The env var BUILDER_FORCE_NUM_PROFILES should override the number of + # optimization profiles during TRT build. + # BUILDER_FORCE_NUM_PROFILES must be less than or equal to the number of + # optimization profiles set by model's prepare_inputs(). + force_num_profiles_from_env = os.environ.get( + "BUILDER_FORCE_NUM_PROFILES", None) + if force_num_profiles_from_env is not None: + logger.warning( + f"Overriding # of builder profiles <= {force_num_profiles_from_env}." 
+ ) + + build_config = BuildConfig.from_dict( + { + 'max_input_len': args.max_input_len, + 'max_seq_len': args.max_seq_len, + 'max_batch_size': args.max_batch_size, + 'max_beam_width': args.max_beam_width, + 'max_num_tokens': args.max_num_tokens, + 'opt_num_tokens': args.opt_num_tokens, + 'max_prompt_embedding_table_size': + args.max_prompt_embedding_table_size, + 'gather_context_logits': args.gather_context_logits, + 'gather_generation_logits': args.gather_generation_logits, + 'strongly_typed': True, + 'force_num_profiles': force_num_profiles_from_env, + 'weight_sparsity': args.weight_sparsity, + 'profiling_verbosity': args.profiling_verbosity, + 'enable_debug_output': args.enable_debug_output, + 'max_draft_len': args.max_draft_len, + 'speculative_decoding_mode': speculative_decoding_mode, + 'input_timing_cache': args.input_timing_cache, + 'output_timing_cache': '/tmp/model.cache', + 'auto_parallel_config': { + 'world_size': + args.auto_parallel, + 'gpus_per_node': + args.gpus_per_node, + 'sharded_io_allowlist': [ + 'past_key_value_\\d+', + 'present_key_value_\\d*', + ], + 'same_buffer_io': { + 'past_key_value_(\\d+)': 'present_key_value_\\1', + }, + **cluster_config, + }, + 'dry_run': args.dry_run, + 'visualize_network': args.visualize_network, + 'max_encoder_input_len': args.max_encoder_input_len, + 'weight_streaming': args.weight_streaming, + 'monitor_memory': args.monitor_memory, + }, + plugin_config=plugin_config) + + if hasattr(args, 'kv_cache_type'): + build_config.update_from_dict( + {'kv_cache_type': args.kv_cache_type}) + else: + build_config = BuildConfig.from_json_file(args.build_config, + plugin_config=plugin_config) + return build_config + + +def parse_quant_config(properties: dict) -> Optional[QuantConfig]: + quant_config = {} + if "quant_algo" in properties: + quant_config["quant_algo"] = QuantAlgo( + properties.pop("quant_algo").upper()) + if "kv_cache_quant_algo" in properties: + quant_config["kv_cache_quant_algo"] = QuantAlgo( + 
properties.pop("kv_cache_quant_algo").upper()) + if "group_size" in properties: + quant_config["group_size"] = int(properties.pop("group_size")) + if "smoothquant_val" in properties: + quant_config["smoothquant_val"] = float( + properties.pop("smoothquant_val")) + if "clamp_val" in properties: + quant_config["clamp_val"] = json.loads(properties.pop("clamp_val")) + if "use_meta_recipe" in properties: + quant_config["use_meta_recipe"] = properties.pop( + "use_meta_recipe").lower() == "true" + if "has_zero_point" in properties: + quant_config["has_zero_point"] = properties.pop( + "has_zero_point").lower() == "true" + if "pre_quant_scales" in properties: + quant_config["pre_quant_scales"] = properties.pop( + "pre_quant_scales").lower() == "true" + if "exclude_modules" in properties: + quant_config["exclude_modules"] = json.loads( + properties.pop("exclude_modules")) + if quant_config: + return QuantConfig(**quant_config) + return None + + +def parse_calib_config(properties: dict) -> Optional[CalibConfig]: + calib_config = {} + if "device" in properties: + calib_config["device"] = properties.pop("device") + if "calib_dataset" in properties: + calib_config["calib_dataset"] = properties.pop("calib_dataset") + if "calib_batches" in properties: + calib_config["calib_batches"] = int(properties.pop("calib_batches")) + if "calib_batch_size" in properties: + calib_config["calib_batch_size"] = int( + properties.pop("calib_batch_size")) + if "calib_max_seq_length" in properties: + calib_config["calib_max_seq_length"] = int( + properties.pop("calib_max_seq_length")) + if "random_seed" in properties: + calib_config["random_seed"] = int(properties.pop("random_seed")) + if "tokenizer_max_seq_length" in properties: + calib_config["tokenizer_max_seq_length"] = int( + properties.pop("tokenizer_max_seq_length")) + if calib_config: + return CalibConfig(**calib_config) + return None + + +def parse_llm_kwargs(properties: dict) -> dict: + llm_kwargs = {} + if "dtype" in properties: + 
llm_kwargs["dtype"] = properties.pop("dtype") + if "revision" in properties: + llm_kwargs["revision"] = properties.pop("revision") + if "trust_remote_code" in properties: + llm_kwargs["trust_remote_code"] = properties.pop("trust_remote_code") + return llm_kwargs -def main(): - logging.basicConfig(stream=sys.stdout, - format="%(message)s", - level=logging.INFO) +def generate_trtllm_build_configs(properties: dict) -> dict: + quant_config = parse_quant_config(properties) + calib_config = parse_calib_config(properties) + build_config = parse_build_config(properties) + + llm_kwargs = {} + if "dtype" in properties: + llm_kwargs["dtype"] = properties.pop("dtype") + if "revision" in properties: + llm_kwargs["revision"] = properties.pop("revision") + if "trust_remote_code" in properties: + llm_kwargs["trust_remote_code"] = properties.pop("trust_remote_code") + + if quant_config: + llm_kwargs["quant_config"] = quant_config + if calib_config: + llm_kwargs["calib_config"] = calib_config + if build_config: + llm_kwargs["build_config"] = build_config + + return { + "llm_kwargs": llm_kwargs, + "dtype": llm_kwargs.get("dtype", "auto"), + "revision": llm_kwargs.get("revision", None), + "trust_remote_code": llm_kwargs.get("trust_remote_code", False) + } + + +def sanitize_serving_properties(model_dir: str) -> dict: + properties = update_kwargs_with_env_vars({}) + properties.update(load_properties(model_dir)) + properties = remove_option_from_properties(properties) + return properties + + +def copy_properties_to_compiled_model_dir(source_path: str, dest_path: str): + with open(os.path.join(source_path, 'serving.properties'), + 'r') as source, open( + os.path.join(dest_path, 'serving.properties'), 'w+') as dest: + for line in source: + if "option.model_id" in line: + continue + dest.write(line) + + +def main(): + logger.set_level('info') parser = argparse.ArgumentParser() parser.add_argument( '--properties_dir', @@ -45,25 +296,32 @@ def main(): type=str, required=True, help='local path 
where trt llm model repo will be created') + parser.add_argument('--model_path', + type=str, + required=False, + default=None, + help='local path to downloaded model') parser.add_argument('--tensor_parallel_degree', type=int, required=True, - help='Tensor parallel degree') + help="tensor parallel degree for compilation") parser.add_argument('--pipeline_parallel_degree', type=int, required=True, - help='Pipeline parallel degree') - parser.add_argument('--model_path', - type=str, - required=False, - default=None, - help='local path to downloaded model') + help="pipeline parallel degree for compilation") args = parser.parse_args() - properties = update_kwargs_with_env_vars({}) - properties.update(load_properties(args.properties_dir)) - create_trt_llm_repo(properties, args) + sanitized_properties = sanitize_serving_properties(args.properties_dir) + trt_build_configs = generate_trtllm_build_configs(sanitized_properties) + build_engine( + trt_build_configs, + args.model_path, + args.trt_llm_model_repo, + args.tensor_parallel_degree, + ) + copy_properties_to_compiled_model_dir(args.properties_dir, + args.trt_llm_model_repo) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/serving/docker/scripts/install_s5cmd.sh b/serving/docker/scripts/install_s5cmd.sh index 56e62a038..d89b939c4 100755 --- a/serving/docker/scripts/install_s5cmd.sh +++ b/serving/docker/scripts/install_s5cmd.sh @@ -4,18 +4,42 @@ set -ex ARCH=$1 -# Download custom s5cmd binary built with Go 1.25.4 +# Install jq if not available (for lmi container) +if ! 
command -v jq &> /dev/null; then + apt-get update && apt-get install -y jq +fi + +# Retrieve latest patched version +GO_MAJOR_MINOR="1.25" +GO_VERSION=$(curl -s https://go.dev/dl/?mode=json | jq -r ".[].version" | grep "^go${GO_MAJOR_MINOR}" | head -1 | sed 's/go//') +echo "Using Go version: ${GO_VERSION} (latest in ${GO_MAJOR_MINOR}.x series)" + if [[ $ARCH == "aarch64" ]]; then - curl -f https://publish.djl.ai/s5cmd/go1.25.4/s5cmd-linux-arm64 -L -o s5cmd + GO_ARCH="arm64" else - curl -f https://publish.djl.ai/s5cmd/go1.25.4/s5cmd-linux-amd64 -L -o s5cmd + GO_ARCH="amd64" fi -INSTALL_DIR="/opt/djl/bin" +curl -fsSL "https://go.dev/dl/go${GO_VERSION}.linux-${GO_ARCH}.tar.gz" | tar -xz -C /tmp +export PATH="/tmp/go/bin:${PATH}" +export GOPATH="/tmp/gopath" +export GOCACHE="/tmp/gocache" +# Download s5cmd release source +S5CMD_VERSION="v2.3.0" +echo "Building s5cmd ${S5CMD_VERSION}" +curl -fsSL "https://github.com/peak/s5cmd/archive/refs/tags/${S5CMD_VERSION}.tar.gz" | tar -xz -C /tmp +mv /tmp/s5cmd-${S5CMD_VERSION#v} /tmp/s5cmd +cd /tmp/s5cmd +go build -ldflags "-X github.com/peak/s5cmd/v2/version.Version=${S5CMD_VERSION}" -o s5cmd . + +# Install s5cmd +INSTALL_DIR="/opt/djl/bin" mkdir -p "${INSTALL_DIR}" mv s5cmd "${INSTALL_DIR}/" chmod +x "${INSTALL_DIR}/s5cmd" +rm -rf /tmp/go /tmp/gopath /tmp/gocache /tmp/s5cmd + export PATH="${INSTALL_DIR}:${PATH}" echo "export PATH=${INSTALL_DIR}:\$PATH" >>~/.bashrc diff --git a/serving/docker/tensorrt-llm.Dockerfile b/serving/docker/tensorrt-llm.Dockerfile index 514fc2a13..5c79f7b5a 100644 --- a/serving/docker/tensorrt-llm.Dockerfile +++ b/serving/docker/tensorrt-llm.Dockerfile @@ -9,31 +9,13 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. 
-ARG version=12.5.1-devel-ubuntu24.04 +ARG version=12.8.1-devel-ubuntu24.04 FROM nvidia/cuda:$version -ARG cuda_version=cu125 -ARG python_version=3.10 -ARG TORCH_VERSION=2.4.0 +ARG cuda_version=cu128 +ARG python_version=3.12 +ARG trtllm_version=1.0.0 ARG djl_version ARG djl_serving_version -ARG transformers_version=4.44.2 -ARG accelerate_version=0.32.1 -ARG tensorrtlibs_version=10.1.0 -# %2B is the url escape for the '+' character -ARG trtllm_toolkit_version=0.12.0%2Bnightly -ARG trtllm_version=v0.12.0 -ARG cuda_python_version=12.5 -ARG peft_version=0.10.0 -ARG triton_version=r24.04 -ARG trtllm_toolkit_wheel="https://publish.djl.ai/tensorrt-llm/toolkit/tensorrt_llm_toolkit-${trtllm_toolkit_version}-py3-none-any.whl" -ARG trtllm_wheel="https://publish.djl.ai/tensorrt-llm/${trtllm_version}/tensorrt_llm-0.12.0-cp310-cp310-linux_x86_64.whl" -ARG triton_toolkit_wheel="https://publish.djl.ai/tritonserver/${triton_version}/tritontoolkit-24.4-py310-none-any.whl" -ARG pydantic_version=2.6.1 -ARG modelopt_version=0.15.0 -ARG janus_version=1.0.0 -ARG pynvml_verison=11.5.0 -ARG numpy_version=1.26.4 -ARG datasets_version=2.19.1 EXPOSE 8080 @@ -67,39 +49,19 @@ RUN mkdir -p /opt/djl/conf && \ COPY config.properties /opt/djl/conf/config.properties COPY partition /opt/djl/partition - -COPY distribution[s]/ ./ -RUN mv *.deb djl-serving_all.deb || true - # Install OpenMPI and other deps ARG DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install -y g++ wget unzip openmpi-bin libopenmpi-dev libffi-dev git-lfs rapidjson-dev graphviz && \ +RUN apt-get update && apt-get install -y g++ wget unzip openmpi-bin libopenmpi-dev libffi-dev git-lfs rapidjson-dev graphviz cuda-compat-12-8 && \ scripts/install_python.sh ${python_version} && \ pip3 cache purge && \ apt-get clean -y && rm -rf /var/lib/apt/lists/* -# Install PyTorch -# Qwen needs transformers_stream_generator, tiktoken and einops -RUN pip install torch==${TORCH_VERSION} transformers==${transformers_version} 
accelerate==${accelerate_version} peft==${peft_version} sentencepiece \ - mpi4py cuda-python==${cuda_python_version} onnx polygraphy pynvml==${pynvml_verison} datasets==${datasets_version} pydantic==${pydantic_version} scipy torchprofile bitsandbytes ninja \ - transformers_stream_generator einops tiktoken jinja2 graphviz blobfile colored h5py strenum pulp flax easydict && \ - pip3 cache purge - -# Install TensorRT and TRT-LLM Deps -RUN pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com tensorrt==${tensorrtlibs_version} janus==${janus_version} nvidia-modelopt==${modelopt_version} && \ - pip install --no-deps ${trtllm_wheel} && \ - pyver=$(echo $python_version | awk -F. '{print $1$2}') && \ - pip3 cache purge +# We install TRTLLM separately because it is hosted in a different pypi index +RUN pip install tensorrt_llm==${trtllm_version} --extra-index-url https://pypi.nvidia.com && \ + pip install tensorrt_llm==${trtllm_version} uvloop ninja -# download dependencies -RUN pip install ${triton_toolkit_wheel} ${trtllm_toolkit_wheel} && \ - mkdir -p /opt/tritonserver/lib && mkdir -p /opt/tritonserver/backends/tensorrtllm && \ - curl -o /opt/tritonserver/lib/libtritonserver.so https://publish.djl.ai/tritonserver/${triton_version}/libtritonserver.so && \ - curl -o /opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm.so https://publish.djl.ai/tensorrt-llm/${trtllm_version}/libtriton_tensorrtllm.so && \ - curl -o /opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm_common.so https://publish.djl.ai/tensorrt-llm/${trtllm_version}/libtriton_tensorrtllm_common.so && \ - curl -o /opt/tritonserver/lib/libnvinfer_plugin_tensorrt_llm.so.10 https://publish.djl.ai/tensorrt-llm/${trtllm_version}/libnvinfer_plugin_tensorrt_llm.so.10 && \ - pip3 cache purge && \ - apt-get clean -y && rm -rf /var/lib/apt/lists/* +COPY distribution[s]/ ./ +RUN mv *.deb djl-serving_all.deb || true # Final steps RUN scripts/install_djl_serving.sh $djl_version 
$djl_serving_version && \ @@ -111,17 +73,13 @@ RUN scripts/install_djl_serving.sh $djl_version $djl_serving_version && \ useradd -m -d /home/djl djl && \ chown -R djl:djl /opt/djl && \ rm -rf scripts && \ - pip3 install numpy==${numpy_version} && \ pip3 cache purge && \ apt-get clean -y && rm -rf /var/lib/apt/lists/* -# Add CUDA-Compat -RUN apt-get update && apt-get install -y cuda-compat-12-4 && apt-get clean -y && rm -rf /var/lib/apt/lists/* - LABEL maintainer="djl-dev@amazon.com" LABEL dlc_major_version="1" LABEL com.amazonaws.ml.engines.sagemaker.dlc.framework.djl.tensorrtllm="true" -LABEL com.amazonaws.ml.engines.sagemaker.dlc.framework.djl.v0-34-0.tensorrtllm="true" +LABEL com.amazonaws.ml.engines.sagemaker.dlc.framework.djl.v0-33-0.tensorrtllm="true" LABEL com.amazonaws.sagemaker.capabilities.multi-models="true" LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port="true" LABEL djl-version=$djl_version @@ -129,4 +87,4 @@ LABEL djl-serving-version=$djl_serving_version LABEL trtllm-version=$trtllm_version LABEL cuda-version=$cuda_version # To use the 535 CUDA driver -LABEL com.amazonaws.sagemaker.inference.cuda.verified_versions=12.2 +LABEL com.amazonaws.sagemaker.inference.cuda.verified_versions=12.2 \ No newline at end of file diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py index 5dc0c7c9d..8a81a7000 100644 --- a/tests/integration/llm/prepare.py +++ b/tests/integration/llm/prepare.py @@ -542,7 +542,6 @@ "llama2-13b": { "option.model_id": "s3://djl-llm/llama-2-13b-hf/", "option.tensor_parallel_degree": 4, - "option.rolling_batch": "trtllm", }, "llama2-7b-smoothquant": { "option.model_id": "s3://djl-llm/meta-llama-Llama-2-7b-chat-hf/", @@ -550,7 +549,6 @@ "option.quantize": "smoothquant", "option.smoothquant_per_token": "True", "option.smoothquant_per_channel": "True", - "option.rolling_batch": "trtllm", }, "internlm-7b": { "option.model_id": "internlm/internlm-7b", @@ -572,7 +570,6 @@ "mistral-7b": { 
"option.model_id": "s3://djl-llm/mistral-7b/", "option.tensor_parallel_degree": 4, - "option.rolling_batch": "trtllm", }, "gpt-j-6b": { "option.model_id": "s3://djl-llm/gpt-j-6b/", @@ -1079,6 +1076,9 @@ def build_trtllm_handler_model(model): f"{model} is not one of the supporting handler {list(trtllm_handler_list.keys())}" ) options = trtllm_handler_list[model] + options["option.rolling_batch"] = "disable" + options["option.async_mode"] = True + options["option.entryPoint"] = "djl_python.lmi_trtllm.trtllm_async_service" # 30 minute waiting for conversion timeout options["model_loading_timeout"] = "1800" write_model_artifacts(options) diff --git a/tests/integration/tests.py b/tests/integration/tests.py index f2e0266dc..ca9360cc4 100644 --- a/tests/integration/tests.py +++ b/tests/integration/tests.py @@ -790,8 +790,7 @@ class TestVllm_p4d: def test_qwen3_vl_32b_instruct(self): with Runner('lmi', 'qwen3-vl-32b-instruct') as r: prepare.build_vllm_async_model("qwen3-vl-32b-instruct") - env = ["VLLM_ATTENTION_BACKEND=TORCH_SDPA"] - r.launch(env_vars=env) + r.launch() client.run("multimodal qwen3-vl-32b-instruct".split()) def test_llama_4_scout_17b_16e_instruct(self): diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java index bf1009ec8..5280438fa 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java @@ -66,12 +66,12 @@ private static void setRollingBatch( } else if (!isTextGenerationModel(modelConfig)) { // Non text-generation use-cases are not compatible with rolling batch rollingBatch = "disable"; - } else if (isVllmEnabled(features)) { rollingBatch = "disable"; lmiProperties.setProperty("option.async_mode", "true"); } else if (isTrtLlmEnabled(features)) { - rollingBatch = "trtllm"; + rollingBatch = "disable"; + lmiProperties.setProperty("option.async_mode", "true"); } else { rollingBatch = 
"disable"; } @@ -143,7 +143,8 @@ private static void setIsPeftModel( } private static void setPropertiesForLora(Properties lmiProperties) { - // If option.enable_lora=true, set load_on_devices=0 and maxWorkers=1 because we only + // If option.enable_lora=true, set load_on_devices=0 and maxWorkers=1 because we + // only // support one worker thread // for LoRA. // TODO: Support multiple worker threads for LoRA. @@ -158,7 +159,8 @@ private static void setPropertiesForLora(Properties lmiProperties) { } private static void setPropertiesForStatefulSessions(Properties lmiProperties) { - // If option.enable_stateful_sessions=true, set load_on_devices=0 and maxWorkers=1 because + // If option.enable_stateful_sessions=true, set load_on_devices=0 and + // maxWorkers=1 because // we // only support one worker thread for stateful sessions. boolean enableStatefulSessions = diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 1a2942a92..da6af1321 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -21,6 +21,7 @@ import ai.djl.util.Utils; import ai.djl.util.cuda.CudaUtils; +import com.google.gson.JsonObject; import com.google.gson.JsonSyntaxException; import com.google.gson.annotations.SerializedName; @@ -42,8 +43,6 @@ import java.util.Properties; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.stream.Stream; /** A utility class to detect optimal engine for LMI model. 
*/ public final class LmiUtils { @@ -111,8 +110,10 @@ static boolean isTrtLlmRollingBatch(Properties properties) { } static boolean needConvertTrtLLM(ModelInfo info) { - Properties properties = info.getProperties(); - return isTrtLlmRollingBatch(properties); + String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES"); + // Pytorch backend cannot be saved as engine currently in TRTLLM... + String backend = info.prop.getProperty("option.backend"); + return features != null && features.contains("trtllm") && !"pytorch".equals(backend); } static void convertTrtLLM(ModelInfo info) throws IOException { @@ -393,7 +394,7 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo private static Path buildTrtLlmArtifacts( Properties prop, String modelId, String tpDegree, String ppDegree) throws IOException { logger.info("Converting model to TensorRT-LLM artifacts"); - String hash = Utils.hash(modelId + tpDegree); + String hash = Utils.hash(modelId + tpDegree + ppDegree); String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null); Path parent = download == null ? 
Utils.getCacheDir() : Paths.get(download); Path trtLlmRepoDir = parent.resolve("trtllm").resolve(hash); @@ -467,23 +468,13 @@ static String getAWSGpuMachineType() { } static boolean isValidTrtLlmModelRepo(Path modelPath) throws IOException { - // TODO: match model name - AtomicBoolean isValid = new AtomicBoolean(); - try (Stream walk = Files.list(modelPath)) { - walk.filter(Files::isDirectory) - .forEach( - p -> { - Path confFile = p.resolve("config.pbtxt"); - // TODO: add stricter check for tokenizer - Path tokenizer = p.resolve("tokenizer_config.json"); - if (Files.isRegularFile(confFile) - && Files.isRegularFile(tokenizer)) { - logger.info("Found triton model: {}", p); - isValid.set(true); - } - }); - } - return isValid.get(); + Path configFile = modelPath.resolve("config.json"); + if (Files.exists(configFile) && Files.isRegularFile(configFile)) { + String config = Files.readString(configFile); + JsonObject json = JsonUtils.GSON.fromJson(config, JsonObject.class); + return json.has("build_config"); + } + return false; } /** @@ -663,9 +654,11 @@ public long getApproxMemoryForSingleSequence(int sequenceLength, int weightBytes */ public boolean isPeftModel() { // TODO: refactor and make this better - // Peft Configs are very different than regular configs and ideally shouldn't be clubbed + // Peft Configs are very different than regular configs and ideally shouldn't be + // clubbed // into this class. - // This method works now, as the only info we need for the peft model is whether it is + // This method works now, as the only info we need for the peft model is whether + // it is // peft return peftType != null; }