
Solar TRT-LLM Truss (#214)
Commit 304413d (1 parent 6957617)
File tree: 13 files changed, +1579 −0 lines changed
Lines changed: 91 additions & 0 deletions
# TRTLLM

### Overview

This Truss adds support for TRT-LLM engines via Triton Inference Server. TRT-LLM is a highly performant language model runtime. We leverage the C++ runtime to take advantage of in-flight batching (also known as continuous batching).
### Prerequisites

To use this Truss, your engine must be built with in-flight batching support. Refer to your architecture-specific `build.py` for instructions on building with in-flight batching support.
### Config

This Truss is primarily config driven, meaning that most of the settings you'll need to edit live in `config.yaml`, underneath the `model_metadata` key.

- `tensor_parallelism` (int): If you built your model with tensor parallelism support, set this to the same value used during the engine build step. It should also match the number of GPUs in the `resources` section.
*Pipeline parallelism is not supported in this version but will be added later. As noted by Nvidia, pipeline parallelism reduces the need for high-bandwidth communication but may incur load-balancing issues and may be less efficient in terms of GPU utilization.*

- `engine_repository` (str): We expect engines to be uploaded to Hugging Face with a flat directory structure (i.e. the engine and associated files are not nested inside a folder). This value is the full `{org_name}/{repo_name}` string. Engines can be private or public.

- `tokenizer_repository` (str): Engines do not come bundled with their own tokenizer. This is the Hugging Face repository where we can find a tokenizer. Tokenizers can be private or public.

If the engine and tokenizer repositories are private, you'll need to update the `secrets` section of the `config.yaml` as follows:
```
secrets:
  hf_access_token: "my_hf_api_key"
```
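
Under the hood, this token is used to authenticate downloads from the Hugging Face Hub. As a rough illustration of what that access looks like (the repository names below are placeholders, and the exact download helper used by this Truss may differ):

```python
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

hf_access_token = "my_hf_api_key"  # value supplied via the `hf_access_token` secret

# Engine files live in a flat HF repo; the tokenizer comes from its own repo.
engine_dir = snapshot_download("my-org/my-engine-repo", token=hf_access_token)
tokenizer = AutoTokenizer.from_pretrained("my-org/my-tokenizer-repo", token=hf_access_token)
```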
### Performance

TRT-LLM engines are designed to be highly performant. Once your Truss has been deployed, you may find that you're not fully utilizing the GPU. The following are levers to improve performance, but they require trial and error to identify appropriate values. All of these values live inside the `config.pbtxt` for a given ensemble model.
#### Preprocessing / Postprocessing

```
instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
```

By default, we load 1 instance of the pre/post models. If you find that the tokenizer is a bottleneck, increasing the `count` variable here will load more replicas of these models and Triton will automatically load balance across model instances.
### TensorRT-LLM

```
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "10000"
  }
}
```

By default, we set `max_tokens_in_paged_kv_cache` to 10000. For a 13B model on 1 A100 with a batch size of 8, we have over 60GB of GPU memory left over, so we can comfortably increase this value to 100k and allow for more tokens in the KV cache. Your mileage will vary based on the size of your model and the hardware you're running on.
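
To pick a value for your own deployment, it helps to estimate the per-token KV-cache footprint. Below is a rough back-of-envelope sketch, assuming a Llama-style 13B model (40 layers, hidden size 5120) with an fp16 KV cache; the exact numbers depend on your architecture and any KV-cache quantization:

```python
# Back-of-envelope KV-cache sizing (illustrative architecture numbers).
num_layers = 40        # assumed for a Llama-style 13B model
hidden_size = 5120     # num_heads * head_dim
bytes_per_value = 2    # fp16

# Each token stores one K and one V vector per layer.
bytes_per_token = 2 * num_layers * hidden_size * bytes_per_value   # ~0.82 MB

max_tokens_in_paged_kv_cache = 10_000
total_gb = bytes_per_token * max_tokens_in_paged_kv_cache / 1e9
print(f"KV cache for {max_tokens_in_paged_kv_cache} tokens: ~{total_gb:.1f} GB")
```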
```
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "0.1"
  }
}
```

If `max_tokens_in_paged_kv_cache` is unset, Triton Inference Server will instead attempt to preallocate this fraction of free GPU memory for the KV cache.
```
parameters: {
  key: "max_num_sequences"
  value: {
    string_value: "64"
  }
}
```

The `max_num_sequences` param is the maximum number of requests that the inference server can maintain state for at a given time (state = KV cache + decoder state). See this [comment](https://github.com/NVIDIA/TensorRT-LLM/issues/65#issuecomment-1774332446) for more details. Setting this value higher allows for more parallel processing but uses more GPU memory.
### API

We expect requests with the following fields:

- `prompt` (str): The prompt you'd like to complete.
- `max_tokens` (int, default: 50): The max token count. This includes the number of tokens in your prompt, so if this value is smaller than your prompt, you'll just receive a truncated version of the prompt.
- `beam_width` (int, default: 1): The number of beams to compute. This must be 1 for this version of TRT-LLM; in-flight batching does not support beams > 1.
- `bad_words_list` (list, default: []): A list of words to exclude from the generated output.
- `stop_words_list` (list, default: []): A list of words that stop generation when encountered.
- `repetition_penalty` (float, default: 1.0): A repetition penalty to discourage repeating tokens.

This Truss will stream responses back. Responses will be buffered chunks of text.
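
For example, a streaming request against a deployed Truss might look like the sketch below (the model ID, API key, and endpoint URL are placeholders for your own Baseten deployment):

```python
import requests

# Placeholder endpoint and API key; substitute your deployment's values.
resp = requests.post(
    "https://model-<MODEL_ID>.api.baseten.co/production/predict",
    headers={"Authorization": "Api-Key <YOUR_API_KEY>"},
    json={
        "prompt": "What is in-flight batching?",
        "max_tokens": 256,
        "beam_width": 1,
        "repetition_penalty": 1.0,
    },
    stream=True,
)

# Responses stream back as buffered chunks of text.
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="", flush=True)
```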
Lines changed: 28 additions & 0 deletions
base_image:
  image: docker.io/baseten/trtllm-server:r23.12_baseten_v0.9.0.dev2024022000
  python_executable_path: /usr/bin/python3
description: Generate text from a prompt with this 10.7 billion parameter language
  model.
environment_variables:
  HF_HUB_ENABLE_HF_TRANSFER: true
external_package_dirs: []
model_metadata:
  avatar_url: https://cdn.baseten.co/production/static/explore/meta.png
  cover_image_url: https://cdn.baseten.co/production/static/explore/llama.png
  engine_repository: baseten/solar10.7
  tags:
  - text-generation
  tensor_parallelism: 1
  tokenizer_repository: upstage/SOLAR-10.7B-Instruct-v1.0
model_name: Solar 10.7B
python_version: py311
requirements:
- tritonclient[all]
- hf_transfer
resources:
  accelerator: H100
  use_gpu: true
runtime:
  predict_concurrency: 256
secrets: {}
system_packages: []
Lines changed: 37 additions & 0 deletions
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
gpt_float16_tp2_rank0.engine filter=lfs diff=lfs merge=lfs -text
gpt_float16_tp2_rank1.engine filter=lfs diff=lfs merge=lfs -text

llama/solar-10b-trt-llm/model/__init__.py

Whitespace-only changes.
Lines changed: 139 additions & 0 deletions
import os
from itertools import count
from pathlib import Path
from threading import Thread

import numpy as np
from client import TritonClient, UserData
from transformers import AutoTokenizer
from utils import download_engine, prepare_grpc_tensor, server_loaded

TRITON_MODEL_REPOSITORY_PATH = Path("/packages/inflight_batcher_llm/")


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._secrets = kwargs["secrets"]
        self._request_id_counter = count(start=1)
        self.triton_client = None
        self.tokenizer = None
        self.uses_openai_api = (
            "openai-compatible" in self._config["model_metadata"]["tags"]
        )

    def load(self):
        tensor_parallel_count = self._config["model_metadata"].get(
            "tensor_parallelism", 1
        )
        pipeline_parallel_count = self._config["model_metadata"].get(
            "pipeline_parallelism", 1
        )
        if "hf_access_token" in self._secrets._base_secrets.keys():
            hf_access_token = self._secrets["hf_access_token"]
        else:
            hf_access_token = None
        is_external_engine_repo = "engine_repository" in self._config["model_metadata"]

        # Instantiate TritonClient
        self.triton_client = TritonClient(
            data_dir=self._data_dir,
            model_repository_dir=TRITON_MODEL_REPOSITORY_PATH,
            parallel_count=tensor_parallel_count * pipeline_parallel_count,
        )

        # Download model from Hugging Face Hub if specified
        if is_external_engine_repo:
            if not server_loaded():
                download_engine(
                    engine_repository=self._config["model_metadata"][
                        "engine_repository"
                    ],
                    fp=self._data_dir,
                    auth_token=hf_access_token,
                )

        # Load Triton Server and model
        tokenizer_repository = self._config["model_metadata"]["tokenizer_repository"]
        env = {"triton_tokenizer_repository": tokenizer_repository}
        if hf_access_token is not None:
            env["HUGGING_FACE_HUB_TOKEN"] = hf_access_token

        self.triton_client.load_server_and_model(env=env)

        # setup eos token
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_repository, token=hf_access_token
        )
        self.eos_token_id = self.tokenizer.eos_token_id

    def predict(self, model_input):
        user_data = UserData()
        model_name = "ensemble"
        stream_uuid = str(os.getpid()) + str(next(self._request_id_counter))

        if self.uses_openai_api:
            prompt = self.tokenizer.apply_chat_template(
                model_input.get("messages"),
                tokenize=False,
            )
        else:
            prompt = model_input.get("prompt")

        max_tokens = model_input.get("max_tokens", 50)
        beam_width = model_input.get("beam_width", 1)
        bad_words_list = model_input.get("bad_words_list", [""])
        stop_words_list = model_input.get("stop_words_list", [""])
        repetition_penalty = model_input.get("repetition_penalty", 1.0)
        ignore_eos = model_input.get("ignore_eos", False)
        stream = model_input.get("stream", True)

        input0 = [[prompt]]
        input0_data = np.array(input0).astype(object)
        output0_len = np.ones_like(input0).astype(np.uint32) * max_tokens
        bad_words_list = np.array([bad_words_list], dtype=object)
        stop_words_list = np.array([stop_words_list], dtype=object)
        stream_data = np.array([[stream]], dtype=bool)
        beam_width_data = np.array([[beam_width]], dtype=np.uint32)
        repetition_penalty_data = np.array([[repetition_penalty]], dtype=np.float32)

        inputs = [
            prepare_grpc_tensor("text_input", input0_data),
            prepare_grpc_tensor("max_tokens", output0_len),
            prepare_grpc_tensor("bad_words", bad_words_list),
            prepare_grpc_tensor("stop_words", stop_words_list),
            prepare_grpc_tensor("stream", stream_data),
            prepare_grpc_tensor("beam_width", beam_width_data),
            prepare_grpc_tensor("repetition_penalty", repetition_penalty_data),
        ]

        if not ignore_eos:
            end_id_data = np.array([[self.eos_token_id]], dtype=np.uint32)
            inputs.append(prepare_grpc_tensor("end_id", end_id_data))
        else:
            # do nothing, trt-llm by default doesn't stop on `eos`
            pass

        # Start GRPC stream in a separate thread
        stream_thread = Thread(
            target=self.triton_client.start_grpc_stream,
            args=(user_data, model_name, inputs, stream_uuid),
        )
        stream_thread.start()

        def generate():
            # Yield results from the queue
            for i in TritonClient.stream_predict(user_data):
                yield i

            # Clean up GRPC stream and thread
            self.triton_client.stop_grpc_stream(stream_uuid, stream_thread)

        if stream:
            return generate()
        else:
            if self.uses_openai_api:
                return "".join(generate())
            else:
                return {"text": "".join(generate())}
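
For reference, a sketch of the input `predict` accepts and how the return shape varies (values below are illustrative):

```python
# Streaming request: predict(...) returns a generator that yields text chunks.
streaming_input = {
    "prompt": "What is in-flight batching?",
    "max_tokens": 128,
    "beam_width": 1,
    "stream": True,
}

# Non-streaming request: predict(...) returns {"text": "..."}, or a plain string
# when the "openai-compatible" tag is set in config.yaml (which also switches the
# input to an OpenAI-style "messages" list).
non_streaming_input = {
    "prompt": "What is in-flight batching?",
    "max_tokens": 128,
    "stream": False,
}
```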
