forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_and_run_ad.py
More file actions
152 lines (129 loc) · 5.54 KB
/
build_and_run_ad.py
File metadata and controls
152 lines (129 loc) · 5.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Main entrypoint to build, test, and prompt AutoDeploy inference models."""
import argparse
import json
from typing import List, Optional, Union
import torch
from simple_config import SimpleConfig
from tensorrt_llm._torch.auto_deploy.models import ModelFactoryRegistry
from tensorrt_llm._torch.auto_deploy.shim import AutoDeployConfig, DemoLLM
from tensorrt_llm._torch.auto_deploy.utils.benchmark import benchmark, store_benchmark_results
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.llmapi.llm import LLM, RequestOutput
from tensorrt_llm.sampling_params import SamplingParams
# Global torch config: raise torch.compile's cache size limit so that large
# models (up to Llama 405B) do not exhaust the default recompile cache.
torch._dynamo.config.cache_size_limit = 20
def get_config_and_check_args() -> SimpleConfig:
    """Parse CLI arguments and build a ``SimpleConfig`` from them.

    Two JSON-valued flags are accepted:
      * ``-c``/``--config``: top-level ``SimpleConfig`` fields.
      * ``-m``/``--model-kwargs``: kwargs forwarded to the model factory.

    Returns:
        The populated ``SimpleConfig`` (its constructor performs validation).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=json.loads)
    parser.add_argument("-m", "--model-kwargs", type=json.loads)
    args = parser.parse_args()
    configs_from_args = args.config or {}
    # argparse exposes --model-kwargs as args.model_kwargs; a plain attribute
    # access is clearer than getattr() without a default (which cannot fall back).
    configs_from_args["model_kwargs"] = args.model_kwargs or {}
    config = SimpleConfig(**configs_from_args)
    ad_logger.info(f"Simple Config: {config}")
    return config
def build_llm_from_config(config: SimpleConfig) -> LLM:
    """Builds a LLM object from our config."""
    # Builder config: sequence-length/batch-size limits plus KV-cache page size.
    build_cfg = BuildConfig(
        max_seq_len=config.max_seq_len, max_batch_size=config.max_batch_size
    )
    build_cfg.plugin_config.tokens_per_block = config.page_size

    # AutoDeploy config: derive the cudagraph/compile switches from the chosen
    # compile backend ("torch-opt" enables both; the others enable one each).
    wants_cudagraph = config.compile_backend in ("torch-opt", "torch-cudagraph")
    wants_compile = config.compile_backend in ("torch-opt", "torch-compile")
    ad_config = AutoDeployConfig(
        use_cuda_graph=wants_cudagraph,
        torch_compile_enabled=wants_compile,
        model_factory=config.model_factory,
        model_kwargs=config.model_kwargs,
        attn_backend=config.attn_backend,
        mla_backend=config.mla_backend,
        skip_loading_weights=config.skip_loading_weights,
        cuda_graph_max_batch_size=config.max_batch_size,
        free_mem_ratio=config.free_mem_ratio,
        simple_shard_only=config.simple_shard_only,
    )
    ad_logger.info(f"AutoDeploy Config: {ad_config}")

    # TODO: let's see if prefetching can't be done through the LLM api?
    # I believe the "classic workflow" invoked via the LLM api can do that.
    # Put everything into the model factory and pre-fetch the checkpoint.
    factory = ModelFactoryRegistry.get(config.model_factory)(
        model=config.model,
        model_kwargs=config.model_kwargs,
        tokenizer=config.tokenizer,
        tokenizer_kwargs=config.tokenizer_kwargs,
        skip_loading_weights=config.skip_loading_weights,
    )
    ad_logger.info(f"Prefetched model : {factory.model}")

    # Select the high-level runtime class and construct the LLM interface object.
    runtime_cls = {"demollm": DemoLLM, "trtllm": LLM}[config.runtime]
    return runtime_cls(
        model=factory.model,
        backend="autodeploy",
        build_config=build_cfg,
        auto_deploy_config=ad_config,
        tensor_parallel_size=config.world_size,
        tokenizer=factory.init_tokenizer() if config.customize_tokenizer else None,
    )
def print_outputs(outs: Union[RequestOutput, List[RequestOutput]]) -> List[List[str]]:
    """Log each prompt/completion pair and return them as [prompt, text] lists."""
    # Normalize a single RequestOutput into a one-element batch.
    batch = [outs] if isinstance(outs, RequestOutput) else outs
    collected: List[List[str]] = []
    for idx, item in enumerate(batch):
        prompt_text = item.prompt
        completion = item.outputs[0].text
        ad_logger.info(f"[PROMPT {idx}] {prompt_text}: {completion}")
        collected.append([prompt_text, completion])
    return collected
@torch.inference_mode()
def main(config: Optional[SimpleConfig] = None):
    """Build the LLM, run the example prompts, and optionally benchmark it.

    Args:
        config: Pre-built ``SimpleConfig``; when ``None`` the config is parsed
            from the command line via :func:`get_config_and_check_args`.
    """
    if config is None:
        config = get_config_and_check_args()
    llm = build_llm_from_config(config)
    # Guarantee engine shutdown even if generation or benchmarking raises, so
    # engine resources (e.g. worker processes) are not leaked on error.
    try:
        # prompt the model and print its output
        ad_logger.info("Running example prompts...")
        outs = llm.generate(
            config.prompt,
            sampling_params=SamplingParams(
                max_tokens=config.max_tokens,
                top_k=config.top_k,
                temperature=config.temperature,
            ),
        )
        results = {"prompts_and_outputs": print_outputs(outs)}
        # run a benchmark for the model with batch_size == config.benchmark_bs
        if config.benchmark and config.runtime != "trtllm":
            ad_logger.info("Running benchmark...")
            # Config fields echoed into the benchmark log prefix.
            keys = [
                "compile_backend",
                "attn_backend",
                "mla_backend",
                "benchmark_bs",
                "benchmark_isl",
                "benchmark_osl",
                "benchmark_num",
            ]
            results["benchmark_results"] = benchmark(
                # Random token ids of shape (benchmark_bs, benchmark_isl); exact
                # token values don't matter for throughput measurement.
                func=lambda: llm.generate(
                    torch.randint(0, 100, (config.benchmark_bs, config.benchmark_isl)).tolist(),
                    sampling_params=SamplingParams(
                        max_tokens=config.benchmark_osl,
                        top_k=None,
                        ignore_eos=True,  # force full benchmark_osl generation
                    ),
                    use_tqdm=False,
                ),
                num_runs=config.benchmark_num,
                log_prefix="Benchmark with " + ", ".join(f"{k}={getattr(config, k)}" for k in keys),
                results_path=config.benchmark_results_path,
            )
        elif config.benchmark:
            ad_logger.info("Skipping simple benchmarking for trtllm...")
        if config.benchmark_store_results:
            store_benchmark_results(results, config.benchmark_results_path)
    finally:
        llm.shutdown()
# Script entry point: parse CLI args and run the full build/prompt/benchmark flow.
if __name__ == "__main__":
    main()