Add Triton + TensorRT-LLM inference example #86
base: master
Changes from 1 commit
Dockerfile

```dockerfile
FROM nvcr.io/nvidia/tritonserver:25.10-trtllm-python-py3

# Environment variables
ENV PYTHONPATH=/usr/local/lib/python3.12/dist-packages:$PYTHONPATH
ENV PYTHONDONTWRITEBYTECODE=1
ENV DEBIAN_FRONTEND=noninteractive
ENV HF_HOME=/persistent-storage/models
ENV TORCH_CUDA_ARCH_LIST=8.6

# Install system dependencies
RUN apt-get update && apt-get install -y \
    git \
    git-lfs \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
RUN pip install --break-system-packages \
    huggingface_hub \
    transformers \
    || true

# Create required directories
RUN mkdir -p \
    /app/model_repository/llama3_2/1 \
    /persistent-storage/models \
    /persistent-storage/engines

# Copy application files
COPY --chmod=755 download_model.py start_triton.sh /app/
COPY model.py /app/model_repository/llama3_2/1/
COPY config.pbtxt /app/model_repository/llama3_2/

# Expose Triton ports
EXPOSE 8000 8001 8002

# Start Triton server
CMD ["/app/start_triton.sh"]
```
cerebrium.toml

```toml
[cerebrium.deployment]
name = "tensorrt-triton-demo"
python_version = "3.12"
disable_auth = true
include = ['./*', 'cerebrium.toml']
exclude = ['.*']
deployment_initialization_timeout = 830

[cerebrium.hardware]
cpu = 4.0
memory = 40.0
compute = "AMPERE_A10"
gpu_count = 1
provider = "aws"
region = "us-east-1"

[cerebrium.scaling]
min_replicas = 0
max_replicas = 2
cooldown = 60
replica_concurrency = 5
scaling_metric = "concurrency_utilization"

[cerebrium.runtime.custom]
port = 8000
healthcheck_endpoint = "/v2/health/live"
readycheck_endpoint = "/v2/health/ready"
dockerfile_path = "./Dockerfile"
```
config.pbtxt

```protobuf
name: "llama3_2"
backend: "python"
max_batch_size: 0

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]

input [
  {
    name: "text_input"
    data_type: TYPE_STRING
    dims: [ 1 ]
  },
  {
    name: "max_tokens"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  }
]

output [
  {
    name: "text_output"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
```
download_model.py

```python
#!/usr/bin/env python3
"""
Download HuggingFace model to persistent storage.
Only downloads if the model doesn't already exist.
"""

import os
from pathlib import Path
from huggingface_hub import snapshot_download, login

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = Path("/persistent-storage/models") / MODEL_ID


def download_model():
    """Download model from HuggingFace if not already present."""
    hf_token = os.environ.get("HF_AUTH_TOKEN")

    if not hf_token:
        print("WARNING: HF_AUTH_TOKEN not set, model download may fail")
        return

    if MODEL_DIR.exists() and any(MODEL_DIR.iterdir()):
        print("✓ Model already exists")
        return

    print("Downloading model from HuggingFace...")
    login(token=hf_token)
    snapshot_download(
        MODEL_ID,
        local_dir=str(MODEL_DIR),
        token=hf_token
    )
    print("✓ Model downloaded successfully")


if __name__ == "__main__":
    download_model()
```

Review discussion:

Contributor: Do you run this with `cerebrium run` or does it run on deploy?

Author: I thought about this, but I'm also not feeling great about reinstalling packages (once through the Dockerfile and once through the toml).
model.py

```python
"""
Triton Python Backend for TensorRT-LLM.

This module implements a Triton Inference Server Python backend that uses
TensorRT-LLM's PyTorch backend for optimized LLM inference.
"""

import numpy as np
import triton_python_backend_utils as pb_utils
import torch
from tensorrt_llm import LLM, SamplingParams, BuildConfig
from tensorrt_llm.plugin.plugin import PluginConfig
from transformers import AutoTokenizer

# Model configuration
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = f"/persistent-storage/models/{MODEL_ID}"


class TritonPythonModel:
    """
    Triton Python Backend model for TensorRT-LLM inference.

    This class handles model initialization, inference requests, and cleanup.
    """

    def initialize(self, args):
        """
        Initialize the model using TensorRT-LLM's PyTorch backend.

        This method is called once when the model is loaded. It:
        1. Loads the tokenizer from HuggingFace
        2. Initializes TensorRT-LLM with the PyTorch backend (loads the model directly)

        Args:
            args: Dictionary containing model configuration from Triton
        """
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

        print("Initializing TensorRT-LLM with PyTorch backend...")

        plugin_config = PluginConfig.from_dict({
            "paged_kv_cache": True,  # Efficient memory usage for KV cache
        })

        # Configure build parameters
        build_config = BuildConfig(
            plugin_config=plugin_config,
            max_input_len=4096,  # Maximum input sequence length
            max_batch_size=1,    # Batch size per request
        )

        self.llm = LLM(
            model=MODEL_DIR,  # HuggingFace model path
            build_config=build_config,
            tensor_parallel_size=torch.cuda.device_count(),
        )
        print("✓ Model ready")

    def execute(self, requests):
        """
        Execute inference requests.

        Processes one or more inference requests, generating text responses
        using the TensorRT-LLM model.

        Args:
            requests: List of InferenceRequest objects from Triton

        Returns:
            List of InferenceResponse objects with generated text
        """
        responses = []

        for request in requests:
            try:
                # Extract input text
                input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
                text = input_tensor.as_numpy()[0].decode('utf-8')

                # Extract optional parameters (with defaults)
                max_tokens = 1024
                temperature = 0.8
                top_p = 0.95

                max_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "max_tokens")
                if max_tokens_tensor is not None:
                    max_tokens = int(max_tokens_tensor.as_numpy()[0])

                temp_tensor = pb_utils.get_input_tensor_by_name(request, "temperature")
                if temp_tensor is not None:
                    temperature = float(temp_tensor.as_numpy()[0])

                top_p_tensor = pb_utils.get_input_tensor_by_name(request, "top_p")
                if top_p_tensor is not None:
                    top_p = float(top_p_tensor.as_numpy()[0])

                # Format prompt using Llama chat template
                messages = [{"role": "user", "content": text}]
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )

                # Configure sampling parameters
                sampling_params = SamplingParams(
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=max_tokens,
                )

                # Generate text
                output = self.llm.generate(prompt, sampling_params)
                generated_text = output.outputs[0].text

                # Create response tensor
                output_tensor = pb_utils.Tensor(
                    "text_output",
                    np.array([generated_text.encode('utf-8')], dtype=object)
                )

                # Create inference response
                inference_response = pb_utils.InferenceResponse(
                    output_tensors=[output_tensor]
                )
                responses.append(inference_response)

            except Exception as e:
                # Handle errors gracefully
                print(f"Error processing request: {e}")
                error_response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(f"Error: {str(e)}")
                )
                responses.append(error_response)

        return responses

    def finalize(self):
        """
        Cleanup when the model is being unloaded.

        Shuts down the TensorRT-LLM engine and clears GPU memory.
        """
        if hasattr(self, 'llm'):
            self.llm.shutdown()
            torch.cuda.empty_cache()
```
start_triton.sh

```bash
#!/bin/bash
set -e

# Download model if not already present
echo "Checking for model..."
python3 /app/download_model.py

# Start Triton Inference Server
echo "Starting Triton Inference Server..."
exec tritonserver \
    --model-repository=/app/model_repository \
    --http-port=8000 \
    --grpc-port=8001 \
    --metrics-port=8002
```
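Once the container is up, the health endpoints that cerebrium.toml points at can be probed directly. A quick sketch, assuming the `requests` package and a server on local port 8000:

```python
import requests

base = "http://localhost:8000"

# Triton returns HTTP 200 on these endpoints when the server is up / the model has loaded
print("live:", requests.get(f"{base}/v2/health/live").status_code)
print("ready:", requests.get(f"{base}/v2/health/ready").status_code)
```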
Review discussion:

Contributor: prompt?

Author: This is only for configuring the shapes and other params. A system prompt would have to go directly into model.py.
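For illustration, a minimal sketch of what that could look like inside `execute()` in model.py, reusing the existing chat-template flow; the system text below is a placeholder and is not part of this PR.

```python
# Hypothetical tweak to model.py's execute(): prepend a system message
# before applying the chat template (placeholder system text).
messages = [
    {"role": "system", "content": "You are a concise, helpful assistant."},
    {"role": "user", "content": text},
]
prompt = self.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
```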