Skip to content

Commit ed12b4a

Browse files
bolasimaspctu
and authored
Abuqader/fp8 tp8 (#244)
Co-authored-by: Abu Qader <[email protected]>
1 parent cd732f5 commit ed12b4a

File tree

14 files changed

+1636
-0
lines changed

14 files changed

+1636
-0
lines changed

mistral/mistral-tp8/config.yaml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
# Truss configuration: Mistral 7B Instruct served via TensorRT-LLM (FP8 engine,
# tensor-parallel across 8 GPUs) inside a Triton server base image.
apply_library_patches: true
base_image:
  image: baseten/trtllm-server:r23.12_baseten_v0.9.0.dev2024022000
  python_executable_path: /usr/bin/python3
build:
  arguments:
    # Pre-built engine repository; name encodes fp8, tp8, input/output len 2048,
    # batch size 128, and the matching TensorRT-LLM dev build.
    engine_repository: baseten/mistral_fp8_tp8_i2048_o2048_bs128-tllm_0.9.0.dev2024022000
    pipeline_parallel_count: 1
    # Must match the accelerator count below (H100:8) — one rank per GPU.
    tensor_parallel_count: 8
    tokenizer_repository: mistralai/Mistral-7B-Instruct-v0.2
bundled_packages_dir: packages
data_dir: data
description: Generate text from a prompt with this seven billion parameter language
  model.
environment_variables: {}
examples_filename: examples.yaml
external_data: null
external_package_dirs: []
input_type: Any
live_reload: false
model_cache: []
model_class_filename: model.py
model_class_name: Model
model_framework: custom
model_metadata:
  tags:
  - text-generation
  # model.py keys OpenAI-style response formatting off this tag.
  - openai-compatible
model_module_dir: model
model_name: Mistral 8
model_type: Model
python_version: py311
requirements:
- tritonclient[all]
- transformers
- jinja2
resources:
  accelerator: H100:8
  use_gpu: true
runtime:
  num_workers: 1
  predict_concurrency: 512
secrets: {}

mistral/mistral-tp8/model/__init__.py

Whitespace-only changes.

mistral/mistral-tp8/model/model.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
from itertools import count
3+
4+
import build_engine_utils
5+
from constants import (
6+
GRPC_SERVICE_PORT,
7+
HF_AUTH_KEY_CONSTANT,
8+
HTTP_SERVICE_PORT,
9+
TOKENIZER_KEY_CONSTANT,
10+
)
11+
from schema import ModelInput, TrussBuildConfig
12+
from transformers import AutoTokenizer
13+
from triton_client import TritonClient, TritonServer
14+
15+
16+
class Model:
    """Truss model serving a TensorRT-LLM engine behind a Triton server.

    Lifecycle: ``load`` builds the engine if required, registers it in a
    Triton model repository, starts the Triton server, and loads the
    tokenizer; ``predict`` forwards requests over gRPC and either streams
    results or returns the joined text.
    """

    def __init__(self, data_dir, config, secrets):
        self._data_dir = data_dir
        self._config = config
        self._secrets = secrets
        # PID + monotonic counter makes request ids unique across workers.
        self._request_id_counter = count(start=1)
        self.triton_client = None
        self.triton_server = None
        self.tokenizer = None
        self.uses_openai_api = None

    def load(self):
        build_config = TrussBuildConfig(**self._config["build"]["arguments"])
        # OpenAI-compatible deployments return a bare string instead of
        # {"text": ...} — keyed off the config.yaml model_metadata tag.
        self.uses_openai_api = "openai-compatible" in self._config.get(
            "model_metadata", {}
        ).get("tags", [])
        hf_access_token = None
        if "hf_access_token" in self._secrets._base_secrets.keys():
            hf_access_token = self._secrets["hf_access_token"]

        # TODO(Abu): Move to pre-runtime
        if build_config.requires_build:
            build_engine_utils.build_engine_from_config_args(
                engine_build_args=build_config.engine_build_args,
                dst=self._data_dir,
            )

        self.triton_server = TritonServer(
            grpc_port=GRPC_SERVICE_PORT,
            http_port=HTTP_SERVICE_PORT,
        )

        # If the engine was just built locally it already lives in data_dir,
        # so no remote engine repository is passed through.
        self.triton_server.create_model_repository(
            truss_data_dir=self._data_dir,
            engine_repository_path=build_config.engine_repository
            if not build_config.requires_build
            else None,
            huggingface_auth_token=hf_access_token,
        )

        env = {}
        if hf_access_token:
            env[HF_AUTH_KEY_CONSTANT] = hf_access_token
        env[TOKENIZER_KEY_CONSTANT] = build_config.tokenizer_repository

        # world_size: one Triton/MPI rank per tensor-parallel shard.
        self.triton_server.start(
            world_size=build_config.tensor_parallel_count,
            env=env,
        )

        self.triton_client = TritonClient(
            grpc_service_port=GRPC_SERVICE_PORT,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            build_config.tokenizer_repository, token=hf_access_token
        )
        self.eos_token_id = self.tokenizer.eos_token_id

    async def predict(self, model_input):
        """Run inference.

        Returns an async generator of text chunks when ``stream`` is set;
        otherwise the fully joined text (a bare string for OpenAI-compatible
        deployments, ``{"text": ...}`` for the rest).
        """
        model_input["request_id"] = str(os.getpid()) + str(
            next(self._request_id_counter)
        )
        model_input["eos_token_id"] = self.eos_token_id

        self.triton_client.start_grpc_stream()

        model_input = ModelInput(**model_input)

        result_iterator = self.triton_client.infer(model_input)

        async def generate():
            async for result in result_iterator:
                yield result

        if model_input.stream:
            return generate()

        # BUG FIX: generate() is an *async* generator, so the original
        # ``"".join(generate())`` raised TypeError — str.join cannot consume
        # an async iterator. Drain it with ``async for`` first.
        full_text = "".join([chunk async for chunk in generate()])
        if self.uses_openai_api:
            return full_text
        return {"text": full_text}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from pathlib import Path
2+
3+
from schema import EngineBuildArgs
4+
5+
6+
def build_engine_from_config_args(
    engine_build_args: EngineBuildArgs,
    dst: Path,
):
    """Build a TensorRT-LLM engine described by *engine_build_args* and copy
    the resulting files into *dst*.

    Files already present in *dst* are left untouched, so a previously
    materialized engine is never clobbered. Returns *dst*.
    """
    import os
    import shutil
    import sys

    # NOTE: These are provided by the underlying base image
    # TODO(Abu): Remove this when we have a better way of handling this
    sys.path.append("/app/baseten")
    from build_engine import Engine, build_engine
    from trtllm_utils import docker_tag_aware_file_cache

    engine = Engine(**engine_build_args.model_dump())

    # The cache keys on the docker image tag, so engines built under one
    # base-image version are not reused under another.
    with docker_tag_aware_file_cache("/root/.cache/trtllm"):
        built_engine = build_engine(engine, download_remote=True)

    # exist_ok avoids the check-then-create race of the original
    # ``if not os.path.exists(dst): os.makedirs(dst)``.
    os.makedirs(dst, exist_ok=True)

    for filename in os.listdir(str(built_engine)):
        source_file = os.path.join(str(built_engine), filename)
        destination_file = os.path.join(dst, filename)
        if not os.path.exists(destination_file):
            shutil.copy(source_file, destination_file)

    return dst
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pathlib import Path

# Location of the Triton model repository bundled in the truss package.
# If changing this path, please update it inside the tensorrt_llm
# config.pbtxt as well.
TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path("/packages/tensorrt_llm_model_repository/")
# Ports the Triton server is started on and the gRPC client connects to
# (see model.py: TritonServer/TritonClient construction).
GRPC_SERVICE_PORT = 8001
HTTP_SERVICE_PORT = 8003
# Environment-variable names injected into the Triton server process
# (model.py builds the env dict from these keys).
HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN"
TOKENIZER_KEY_CONSTANT = "TRITON_TOKENIZER_REPOSITORY"
# NOTE(review): presumably the Triton ensemble model that requests are
# routed to — confirm against the model repository's config.pbtxt.
ENTRYPOINT_MODEL_NAME = "ensemble"

0 commit comments

Comments
 (0)