Dockerfile
@@ -0,0 +1,39 @@
FROM nvcr.io/nvidia/tritonserver:25.10-trtllm-python-py3

# Environment variables
ENV PYTHONPATH=/usr/local/lib/python3.12/dist-packages:$PYTHONPATH
ENV PYTHONDONTWRITEBYTECODE=1
ENV DEBIAN_FRONTEND=noninteractive
ENV HF_HOME=/persistent-storage/models
ENV TORCH_CUDA_ARCH_LIST=8.6

# Install system dependencies
RUN apt-get update && apt-get install -y \
git \
git-lfs \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
RUN pip install --break-system-packages \
    huggingface_hub \
    transformers \
    || true  # best-effort: tolerate failure if the base image already provides these packages

# Create required directories
RUN mkdir -p \
/app/model_repository/llama3_2/1 \
/persistent-storage/models \
/persistent-storage/engines

# Copy application files
COPY --chmod=755 download_model.py start_triton.sh /app/
COPY model.py /app/model_repository/llama3_2/1/
COPY config.pbtxt /app/model_repository/llama3_2/

# Expose Triton ports
EXPOSE 8000 8001 8002

# Start Triton server
CMD ["/app/start_triton.sh"]
cerebrium.toml
@@ -0,0 +1,28 @@
[cerebrium.deployment]
name = "tensorrt-triton-demo"
python_version = "3.12"
disable_auth = true
include = ['./*', 'cerebrium.toml']
exclude = ['.*']
deployment_initialization_timeout = 830

[cerebrium.hardware]
cpu = 4.0
memory = 40.0
compute = "AMPERE_A10"
gpu_count = 1
provider = "aws"
region = "us-east-1"

[cerebrium.scaling]
min_replicas = 0
max_replicas = 2
cooldown = 60
replica_concurrency = 5
scaling_metric = "concurrency_utilization"

[cerebrium.runtime.custom]
port = 8000
healthcheck_endpoint = "/v2/health/live"
readycheck_endpoint = "/v2/health/ready"
dockerfile_path = "./Dockerfile"
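
The custom runtime above exposes Triton's standard health endpoints on port 8000. A minimal readiness-poll sketch (illustrative, not part of the PR's files), where BASE_URL is a placeholder for whatever URL Cerebrium assigns to the deployment:

import time
import requests

BASE_URL = "https://example-deployment-url"  # placeholder; substitute the real deployment URL

def wait_until_ready(timeout_s: int = 600) -> bool:
    """Poll Triton's readiness endpoint (the readycheck_endpoint configured above)."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{BASE_URL}/v2/health/ready", timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server may still be starting up
        time.sleep(5)
    return False

if __name__ == "__main__":
    print("ready" if wait_until_ready() else "timed out waiting for readiness")
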
config.pbtxt
@@ -0,0 +1,44 @@
name: "llama3_2"
backend: "python"
max_batch_size: 0

instance_group [
{
count: 1
kind: KIND_GPU
}
]

input [
{
name: "text_input"
Reviewer comment (Contributor): prompt?
Reply (Author): This is only for configuring the shapes and other parameters. A system prompt would have to go directly into model.py (see the note in its execute method below).

data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "max_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [ 1 ]
optional: true
},
{
name: "top_p"
data_type: TYPE_FP32
dims: [ 1 ]
optional: true
}
]

output [
{
name: "text_output"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
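
The tensors declared above map directly onto Triton's KServe v2 HTTP API, so the model can be called with a plain JSON request to /v2/models/llama3_2/infer. A minimal client sketch (illustrative, not part of the PR's files); BASE_URL is a placeholder:

import requests

BASE_URL = "http://localhost:8000"  # placeholder; use the deployed URL in practice

payload = {
    "inputs": [
        # TYPE_STRING tensors are sent as datatype BYTES with string data
        {"name": "text_input", "shape": [1], "datatype": "BYTES",
         "data": ["What is the capital of France?"]},
        # The optional sampling parameters can be omitted to use the defaults in model.py
        {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [256]},
        {"name": "temperature", "shape": [1], "datatype": "FP32", "data": [0.7]},
        {"name": "top_p", "shape": [1], "datatype": "FP32", "data": [0.95]},
    ],
    "outputs": [{"name": "text_output"}],
}

resp = requests.post(f"{BASE_URL}/v2/models/llama3_2/infer", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["outputs"][0]["data"][0])  # generated text
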
download_model.py
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
Download HuggingFace model to persistent storage.
Only downloads if model doesn't already exist.
"""

Reviewer comment (Contributor): Do you run this with cerebrium run, or does it run on deploy?
Reply (Author): I thought about this, but I'm also not keen on reinstalling packages twice (once through the Dockerfile and once through the TOML).

import os
from pathlib import Path
from huggingface_hub import snapshot_download, login

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = Path("/persistent-storage/models") / MODEL_ID


def download_model():
"""Download model from HuggingFace if not already present."""
hf_token = os.environ.get("HF_AUTH_TOKEN")

if not hf_token:
print("WARNING: HF_AUTH_TOKEN not set, model download may fail")
return

if MODEL_DIR.exists() and any(MODEL_DIR.iterdir()):
print("✓ Model already exists")
return

print("Downloading model from HuggingFace...")
login(token=hf_token)
snapshot_download(
MODEL_ID,
local_dir=str(MODEL_DIR),
token=hf_token
)
print("✓ Model downloaded successfully")


if __name__ == "__main__":
download_model()
2-advanced-concepts/6-faster-inference-with-triton-tensorrt/model.py
@@ -0,0 +1,151 @@
"""
Triton Python Backend for TensorRT-LLM.

This module implements a Triton Inference Server Python backend that uses
TensorRT-LLM's PyTorch backend for optimized LLM inference.
"""

import numpy as np
import triton_python_backend_utils as pb_utils
import torch
from tensorrt_llm import LLM, SamplingParams, BuildConfig
from tensorrt_llm.plugin.plugin import PluginConfig
from transformers import AutoTokenizer

# Model configuration
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = f"/persistent-storage/models/{MODEL_ID}"


class TritonPythonModel:
"""
Triton Python Backend model for TensorRT-LLM inference.

This class handles model initialization, inference requests, and cleanup.
"""

def initialize(self, args):
"""
Initialize the model using TensorRT-LLM's PyTorch backend.

This method is called once when the model is loaded. It:
1. Loads the tokenizer from HuggingFace
2. Initializes TensorRT-LLM with PyTorch backend (loads model directly)

Args:
args: Dictionary containing model configuration from Triton
"""
print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

print("Initializing TensorRT-LLM with PyTorch backend...")


plugin_config = PluginConfig.from_dict({
"paged_kv_cache": True, # Efficient memory usage for KV cache
})

# Configure build parameters
build_config = BuildConfig(
plugin_config=plugin_config,
max_input_len=4096, # Maximum input sequence length
            max_batch_size=1,           # Maximum batch size the engine will accept
)

self.llm = LLM(
model=MODEL_DIR, # HuggingFace model path
build_config=build_config,
tensor_parallel_size=torch.cuda.device_count(),
)
print("✓ Model ready")

def execute(self, requests):
"""
Execute inference requests.

Processes one or more inference requests, generating text responses
using the TensorRT-LLM model.

Args:
requests: List of InferenceRequest objects from Triton

Returns:
List of InferenceResponse objects with generated text
"""
responses = []

for request in requests:
Reviewer comment (Contributor): This doesn't look that efficient.
(A batched-generation sketch is included after this file.)

try:
# Extract input text
input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
text = input_tensor.as_numpy()[0].decode('utf-8')

# Extract optional parameters (with defaults)
max_tokens = 1024
temperature = 0.8
top_p = 0.95

max_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "max_tokens")
if max_tokens_tensor is not None:
max_tokens = int(max_tokens_tensor.as_numpy()[0])

temp_tensor = pb_utils.get_input_tensor_by_name(request, "temperature")
if temp_tensor is not None:
temperature = float(temp_tensor.as_numpy()[0])

top_p_tensor = pb_utils.get_input_tensor_by_name(request, "top_p")
if top_p_tensor is not None:
top_p = float(top_p_tensor.as_numpy()[0])

# Format prompt using Llama chat template
messages = [{"role": "user", "content": text}]
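                # A system prompt (see the config.pbtxt review thread) would be injected
                # here rather than in config.pbtxt, e.g. by prepending
                # {"role": "system", "content": SYSTEM_PROMPT} to `messages` before
                # applying the chat template. (SYSTEM_PROMPT is hypothetical; it is not
                # defined anywhere in this PR.)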
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)

# Configure sampling parameters
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
)

# Generate text
output = self.llm.generate(prompt, sampling_params)
generated_text = output.outputs[0].text

# Create response tensor
output_tensor = pb_utils.Tensor(
"text_output",
np.array([generated_text.encode('utf-8')], dtype=object)
)

# Create inference response
inference_response = pb_utils.InferenceResponse(
output_tensors=[output_tensor]
)
responses.append(inference_response)

except Exception as e:
# Handle errors gracefully
print(f"Error processing request: {e}")
error_response = pb_utils.InferenceResponse(
output_tensors=[],
error=pb_utils.TritonError(f"Error: {str(e)}")
)
responses.append(error_response)

return responses

def finalize(self):
"""
Cleanup when model is being unloaded.

Shuts down the TensorRT-LLM engine and clears GPU memory.
"""
if hasattr(self, 'llm'):
self.llm.shutdown()
torch.cuda.empty_cache()
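
On the reviewer's note about the per-request loop in execute: one direction would be to collect all prompts from a callback and generate them in a single call, assuming the TensorRT-LLM LLM.generate API accepts a list of prompts and returns one output per prompt. A rough sketch only, not a drop-in replacement (error handling and per-request sampling parameters are omitted, and max_batch_size in BuildConfig and config.pbtxt would also need to allow batching):

    def execute_batched(self, requests):
        """Sketch: build one prompt per request and generate the whole batch at once."""
        prompts = []
        for request in requests:
            text = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0].decode("utf-8")
            prompts.append(self.tokenizer.apply_chat_template(
                [{"role": "user", "content": text}],
                tokenize=False,
                add_generation_prompt=True,
            ))

        # One generate() call over the whole batch; defaults used instead of per-request parameters.
        sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
        outputs = self.llm.generate(prompts, sampling_params)

        responses = []
        for output in outputs:
            tensor = pb_utils.Tensor(
                "text_output",
                np.array([output.outputs[0].text.encode("utf-8")], dtype=object),
            )
            responses.append(pb_utils.InferenceResponse(output_tensors=[tensor]))
        return responses
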

start_triton.sh
@@ -0,0 +1,14 @@
#!/bin/bash
set -e

# Download model if not already present
echo "Checking for model..."
python3 /app/download_model.py

# Start Triton Inference Server
echo "Starting Triton Inference Server..."
exec tritonserver \
--model-repository=/app/model_repository \
--http-port=8000 \
--grpc-port=8001 \
--metrics-port=8002