
Commit c7d2151

pankajroark and aspctu authored
Move mythomax model from sq to fp8 (#230)
Co-authored-by: Abu Qader <[email protected]>
1 parent 5e01087 commit c7d2151

File tree: 14 files changed, +45 -22 lines

llama/mythomax-13b-trt-sq/README.md renamed to llama/mythomax-13b-trt-fp8/README.md

Lines changed: 4 additions & 5 deletions
@@ -2,10 +2,9 @@
 
 Based on https://huggingface.co/Gryphe/MythoMax-L2-13b
 
-int8 quantized using smoothquant using `https://huggingface.co/datasets/royallab/PIPPA-cleaned` dataset
-smoothquant alpha value used: 0.5
+fp8 quantized using `https://huggingface.co/datasets/royallab/PIPPA-cleaned` dataset
 
-TensorRT-LLM engine is here: https://huggingface.co/baseten/Gryphe_MythoMax-L2-13b_v0.7.1_H100-80GB-HBM3_2ff724
+TensorRT-LLM engine is here: https://huggingface.co/baseten/Gryphe_MythoMax-L2-13b_v0.7.1_H100-80GB-HBM3_fp8
 
 Max input tokens: 3000
 Max output tokens: 2000
@@ -25,15 +24,15 @@ First, clone this repository:
 
 ```sh
 git clone https://github.com/basetenlabs/truss-examples/
-cd llama/mythomax-13b-trt-sq
+cd llama/mythomax-13b-trt-fp8
 ```
 
 Before deployment:
 
 1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
 2. Install the latest version of Truss: `pip install --upgrade truss`
 
-With `mythomax-13b-trt-sq` as your working directory, you can deploy the model with:
+With `mythomax-13b-trt-fp8` as your working directory, you can deploy the model with:
 
 ```sh
 truss push --publish
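
After `truss push --publish` completes, the published model can be called over HTTP. Below is a minimal sketch in Python, assuming the standard Baseten predict endpoint and two placeholder values, `MODEL_ID` and `BASETEN_API_KEY`, that you supply yourself; the payload mirrors the `example_model_input` in config.yaml.

```python
# Hedged sketch: call the deployed MythoMax model on Baseten.
# MODEL_ID and BASETEN_API_KEY are placeholders, not values from this commit.
import os

import requests

model_id = os.environ["MODEL_ID"]        # e.g. the id shown after `truss push`
api_key = os.environ["BASETEN_API_KEY"]

resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {api_key}"},
    json={"prompt": "What's the meaning of life?", "max_tokens": 1024},
)
print(resp.json())
```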
File renamed without changes.

llama/mythomax-13b-trt-sq/config.yaml renamed to llama/mythomax-13b-trt-fp8/config.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 base_image:
-  image: docker.io/baseten/trtllm-server:r23.12_baseten_v0.7.1
+  image: docker.io/baseten/trtllm-server:r23.12_baseten_v0.9.0.dev2024022000
   python_executable_path: /usr/bin/python3
 description: Generate text from a prompt with this seven billion parameter language
   model.
@@ -9,7 +9,7 @@ external_package_dirs: []
 model_metadata:
   avatar_url: https://cdn.baseten.co/production/static/explore/meta.png
   cover_image_url: https://cdn.baseten.co/production/static/explore/llama.png
-  engine_repository: baseten/Gryphe_MythoMax-L2-13b_v0.7.1_H100-80GB-HBM3_2ff724
+  engine_repository: baseten/Gryphe_MythoMax-L2-13b_v0.7.1_H100-80GB-HBM3_fp8
   example_model_input:
     max_tokens: 1024
     prompt: What's the meaning of life?
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

llama/mythomax-13b-trt-sq/packages/inflight_batcher_llm/postprocessing/1/model.py renamed to llama/mythomax-13b-trt-fp8/packages/inflight_batcher_llm/postprocessing/1/model.py

Lines changed: 39 additions & 15 deletions
@@ -32,6 +32,8 @@
 import triton_python_backend_utils as pb_utils
 from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
 
+INVALID_UNICODE_CHAR = "�"
+
 
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
@@ -55,7 +57,8 @@ def initialize(self, args):
         """
         # Parse model configs
         model_config = json.loads(args["model_config"])
-        tokenizer_dir = os.environ["triton_tokenizer_repository"]
+        # NOTE: Keep this in sync with the truss model.py variable
+        tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"]
         tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"]
 
         if tokenizer_type == "t5":
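
The NOTE above refers to the Truss-side model.py (not part of this diff), which launches the Triton server and must export the same variable name this backend reads. A hedged sketch of that handoff; the repository value and launch command are illustrative, not taken from this commit:

```python
# Hypothetical launcher-side handoff (not the actual truss model.py).
import os
import subprocess

env = os.environ.copy()
# Must stay in sync with postprocessing/1/model.py:
#   tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"]
env["TRITON_TOKENIZER_REPOSITORY"] = "Gryphe/MythoMax-L2-13b"  # illustrative value

# The Triton server process inherits the variable from its parent.
subprocess.Popen(
    ["tritonserver", "--model-repository=/packages/inflight_batcher_llm"],
    env=env,
)
```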
@@ -115,24 +118,48 @@ def execute(self, requests):
                 .as_numpy()
                 .flatten()
             )
+
             if len(tokens_batch) == 0:
                 continue
 
             # Postprocess output data
-            prev_token = self._get_prev_token(request_id)
-            self._store_prev_token(request_id, tokens_batch[-1])
+            prev_token = self._get_var(request_id, "prev_token")
+            token_buffer = self._get_var(request_id, "token_buffer")
+            token_buffer = token_buffer if token_buffer is not None else []
+            current_tokens = np.concatenate(
+                (np.array(token_buffer, dtype=int), tokens_batch), dtype=int
+            )
+            current_tokens_decoded = self.tokenizer.decode(current_tokens)
+
+            if len(current_tokens_decoded) == 0:
+                responses.append(pb_utils.InferenceResponse())
+                continue
+
+            if current_tokens_decoded[-1] == INVALID_UNICODE_CHAR:
+                # If the last token is invalid, we need to keep it in the buffer
+                # for the next request to see if this is a multi-token unicode
+                # character.
+                self._store_var(request_id, "token_buffer", current_tokens)
+                responses.append(pb_utils.InferenceResponse())
+                continue
+
             if prev_token is None:
-                delta = self.tokenizer.decode(tokens_batch)
+                delta = current_tokens_decoded
             else:
                 # TODO(pankaj) Figure out how to make tokenizer.decode not
                 # ignore initial whitespace so we can avoid this hack.
                 # Get string with and without previous token and diff. This hack
                 # is needed because tokenizer.decode strips initial whitespace.
-                old_string = self.tokenizer.decode([prev_token])
-                with_prev_token = np.concatenate(([prev_token], tokens_batch))
+                old_string = self.tokenizer.decode(prev_token)
+                with_prev_token = np.concatenate((prev_token, current_tokens))
                 new_string = self.tokenizer.decode(with_prev_token)
                 delta = self._compute_delta(old_string, new_string)
 
+            # The previous token is the last character of the decoded sequence
+            # which includes the multi-token unicode character.
+            self._store_var(request_id, "prev_token", current_tokens)
+            self._store_var(request_id, "token_buffer", None)
+
             # Create output tensor
             output_tensor = pb_utils.Tensor(
                 "OUTPUT", np.array([delta]).astype(self.output_dtype)
@@ -147,22 +174,19 @@ def execute(self, requests):
     def finalize(self):
         print("Cleaning up...")
 
-    def _store_prev_token(self, request_id, token):
+    def _store_var(self, request_id, var_name, var):
         if request_id in self.state_dict:
-            self.state_dict[request_id]["prev_token"] = token
-
-            # Move request ID to end of queue to prevent it from being evicted
+            self.state_dict[request_id][var_name] = var
             self.state_dict.move_to_end(request_id)
         else:
-            # Evict least recently used item if cache is full
             if len(self.state_dict) > self.cache_size:
                 self.state_dict.popitem(last=False)
+            self.state_dict[request_id] = {"prev_token": None, "token_buffer": None}
+            self.state_dict[request_id][var_name] = var
 
-            self.state_dict[request_id] = {"prev_token": token}
-
-    def _get_prev_token(self, request_id):
+    def _get_var(self, request_id, var_name):
         if request_id in self.state_dict:
-            return self.state_dict[request_id]["prev_token"]
+            return self.state_dict[request_id][var_name]
         return None
 
     def _compute_delta(self, prev_str, new_str):
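
The `_store_var`/`_get_var` pair above is a small per-request LRU cache built on `collections.OrderedDict`; `self.state_dict` and `self.cache_size` are set up elsewhere in the file. A hedged standalone sketch of the same pattern, with the cache size chosen arbitrarily for illustration:

```python
# Illustrative per-request state cache with least-recently-used eviction,
# mirroring the OrderedDict pattern used by the postprocessing model.
from collections import OrderedDict

class RequestStateCache:
    def __init__(self, cache_size: int = 1024):
        self.cache_size = cache_size
        self.state_dict: "OrderedDict[str, dict]" = OrderedDict()

    def store_var(self, request_id: str, var_name: str, var) -> None:
        if request_id in self.state_dict:
            self.state_dict[request_id][var_name] = var
            # Mark this request as most recently used.
            self.state_dict.move_to_end(request_id)
        else:
            # Evict the least recently used request if the cache is full.
            if len(self.state_dict) > self.cache_size:
                self.state_dict.popitem(last=False)
            self.state_dict[request_id] = {"prev_token": None, "token_buffer": None}
            self.state_dict[request_id][var_name] = var

    def get_var(self, request_id: str, var_name: str):
        if request_id in self.state_dict:
            return self.state_dict[request_id][var_name]
        return None
```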
File renamed without changes.
