
Commit 6eff0da

full scale hf chat template
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent b64ad52 commit 6eff0da

File tree

7 files changed: 154 additions & 152 deletions

examples/auto_deploy/build_and_run_ad.py

Lines changed: 43 additions & 55 deletions

@@ -26,6 +26,9 @@
 # Global torch config, set the torch compile cache to fix up to llama 405B
 torch._dynamo.config.cache_size_limit = 20
 
+# simple string, TRT-LLM style text-only prompt or full-scale HF message template
+PromptInput = Union[str, Dict, List[Dict]]
+
 
 class PromptConfig(BaseModel):
     """Prompt configuration.
@@ -35,17 +38,27 @@ class PromptConfig(BaseModel):
     """
 
     batch_size: int = Field(default=2, description="Number of queries")
-    queries: Union[str, List[str]] = Field(
+    queries: Union[PromptInput, List[PromptInput]] = Field(
         default_factory=lambda: [
+            # OPTION 1: simple text prompt
             "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-            "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-        ]
+            # OPTION 2: wrapped text prompt for TRT-LLM
+            {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
+            # OPTION 3: a full-scale HF message template (this one works for text-only models!)
+            # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
+            # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
+            [
+                {
+                    "role": "user",
+                    "content": "How to fix slicing in golf?",
+                }
+            ],
+            # More prompts...
+            {"prompt": "Where is the capital of Iceland? "},
+        ],
+        description="Example queries to prompt the model with. We support both TRT-LLM text-only "
+        "queries via the 'prompt' key and full-scale HF message template called via "
+        "apply_chat_template.",
     )
     sp_kwargs: Dict[str, Any] = Field(
         default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
@@ -59,10 +72,28 @@ def model_post_init(self, __context: Any):
         NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
         validators are only run if a value is provided.
         """
-        queries = [self.queries] if isinstance(self.queries, str) else self.queries
+        queries = self.queries if isinstance(self.queries, list) else [self.queries]
         batch_size = self.batch_size
         queries = queries * (batch_size // len(queries) + 1)
-        self.queries = queries[:batch_size]
+        queries = queries[:batch_size]
+
+        # now let's standardize the queries for the LLM api to understand them
+        queries_processed = []
+        for query in queries:
+            if isinstance(query, str):
+                queries_processed.append({"prompt": query})
+            elif isinstance(query, dict):
+                queries_processed.append(query)
+            elif isinstance(query, list):
+                queries_processed.append(
+                    {
+                        "prompt": "Fake prompt. Check out messages field for the HF chat template.",
+                        "messages": query,  # contains the actual HF chat template
+                    }
+                )
+            else:
+                raise ValueError(f"Invalid query type: {type(query)}")
+        self.queries = queries_processed
 
     @field_validator("sp_kwargs", mode="after")
     @classmethod
@@ -237,56 +268,13 @@ def main(config: Optional[ExperimentConfig] = None):
 
     llm = build_llm_from_config(config)
 
-    # just run config.prompt.queries with our special token sequence including special image tokens
-    # fmt: off
-    input_ids = [[
-        200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080,
-        200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
-        200092, 200081, 51212, 1780, 650, 2556, 310, 290, 1472,
-        8392, 341, 1357, 13492, 26, 200008, 200005, 140680, 200006,
-        368
-    ] for _ in range(2)]
-    # fmt: on
-
     # prompt the model and print its output
     ad_logger.info("Running example prompts...")
 
     # now let's try piping through multimodal data
 
     outs = llm.generate(
-        input_ids,
-        # config.prompt.queries,
+        config.prompt.queries,
         sampling_params=SamplingParams(**config.prompt.sp_kwargs),
     )
     results = {"prompts_and_outputs": print_outputs(outs)}
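For reference, the query normalization that the new model_post_init performs can be exercised in isolation with the stand-alone sketch below. It mirrors the logic added above but is not part of the commit; the helper name normalize_queries is purely illustrative.

from typing import Dict, List, Union

# Same alias as introduced in build_and_run_ad.py above.
PromptInput = Union[str, Dict, List[Dict]]


def normalize_queries(
    queries: Union[PromptInput, List[PromptInput]], batch_size: int
) -> List[Dict]:
    """Illustrative helper mirroring PromptConfig.model_post_init."""
    queries = queries if isinstance(queries, list) else [queries]
    # repeat the provided queries until the batch is full, then truncate
    queries = (queries * (batch_size // len(queries) + 1))[:batch_size]

    processed = []
    for query in queries:
        if isinstance(query, str):
            # OPTION 1: plain string -> wrap it into a TRT-LLM style prompt dict
            processed.append({"prompt": query})
        elif isinstance(query, dict):
            # OPTION 2: already a TRT-LLM style {"prompt": ...} dict
            processed.append(query)
        elif isinstance(query, list):
            # OPTION 3: HF chat-template messages; keep them under "messages"
            processed.append({"prompt": "placeholder", "messages": query})
        else:
            raise ValueError(f"Invalid query type: {type(query)}")
    return processed


if __name__ == "__main__":
    examples = [
        "How big is the universe? ",
        {"prompt": "Where is the capital of Iceland? "},
        [{"role": "user", "content": "How to fix slicing in golf?"}],
    ]
    for q in normalize_queries(examples, batch_size=4):
        print(q)

Running it prints one {"prompt": ...} dict per batch slot, with chat-template queries carrying their original message list under the "messages" key.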

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 1 addition & 0 deletions

@@ -577,6 +577,7 @@ def nest_sequences(
         else:
             self._extra_args[name] = none_input
 
+        # TODO (lucaslie): how strict do we wanna be here? Should we just warn/ignore instead?
         assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
 
     def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
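The new TODO above asks whether unknown extra arguments should keep failing hard or merely be warned about. As a hedged sketch of the warn-and-ignore alternative (not what the commit implements), a stand-alone check could look like the helper below; the function name check_extra_args is hypothetical.

import logging

logger = logging.getLogger(__name__)


def check_extra_args(extra_args: dict, strict: bool = True) -> None:
    """Hypothetical variant of the assert above: fail hard or warn and drop."""
    if not extra_args:
        return
    msg = f"Extra arguments {list(extra_args.keys())} not found"
    if strict:
        raise AssertionError(msg)
    logger.warning("%s; ignoring them.", msg)
    extra_args.clear()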

tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 82 additions & 11 deletions

@@ -1,19 +1,83 @@
 import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from ...executor.result import CompletionOutput
-from ...inputs.registry import create_input_processor
+from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
 from ...llmapi.llm import RequestOutput, _TorchLLM
-from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
+from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
+from ...sampling_params import SamplingParams
 from .distributed import common as dist_ad
 from .llm_args import LlmArgs
+from .models.factory import ModelFactory
 from .shim.demollm import DemoGenerationExecutor
 
 
+class ADInputProcessor(DefaultInputProcessor):
+    """Input processor for AutoDeploy backend.
+
+    This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
+    message chat template system to process multimodal inputs.
+    """
+
+    def __init__(self, tokenizer: TokenizerBase, processor: Optional[Any] = None):
+        super().__init__(None, None, tokenizer)
+        # NOTE: HF's tokenizer/processor that has the apply_chat_template method
+        self.processor = processor or tokenizer.tokenizer
+
+    def __call__(
+        self, inputs: Dict[str, Any], sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        # construct kwargs to reflect DefaultInputProcessor
+        kwargs = {
+            "add_special_tokens": sampling_params.add_special_tokens,
+        }
+        if sampling_params.truncate_prompt_tokens is not None:
+            kwargs = {
+                "truncation": True,
+                "max_length": sampling_params.truncate_prompt_tokens,
+            }
+        # check for messages field and if yes, use the apply_chat_template method
+        if "messages" in inputs:
+            # TODO: we don't really need this but it makes for a good sanity check. Consider
+            # removing this in the future if we need to speed things up.
+            prompt = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            inputs["prompt"] = prompt
+
+            all_args = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=False,  # there shouldn't be a need for padding ever...
+                return_attention_mask=False,
+                **kwargs,
+            )
+            # TODO: is there a more reliable way to avoid the attention_mask here?
+            all_args.pop("attention_mask", None)
+
+            # TODO: can we avoid the extra tolist() here eventually?
+            token_ids = all_args.pop("input_ids")
+            assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
+            return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None
+        else:
+            token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
+            return token_ids, None
+
+
 class LLM(_TorchLLM):
     """LLM class is the main class for running an LLM model using AutoDeploy backend."""
 
     args: LlmArgs
+    _factory: ModelFactory
+
+    @property
+    def factory(self) -> ModelFactory:
+        return self._factory
 
     def __init__(self, *args, **kwargs):
         kwargs["backend"] = "_autodeploy"
@@ -23,30 +87,36 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
         if self.args.skip_tokenizer_init:
             return None
 
-        factory = self.args.create_factory()
-        return tokenizer_factory(factory.init_tokenizer())
+        return tokenizer_factory(self._factory.init_tokenizer())
 
     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """We don't need to validate args for AutoDeploy backend for now."""
         pass
 
-    def _prefetch_model(self):
-        """Prefetch the model for the LLM."""
-        self.args.create_factory().prefetch_checkpoint()
+    def _create_input_processor(self) -> ADInputProcessor:
+        return ADInputProcessor(self.tokenizer, self._factory.init_processor())
 
     def _build_model(self):
         """Build the model for the LLM.
 
         This is a wrapper around the regular build model method that prefetches the model with the
        factory.
         """
+        # create and store a factory
+        self._factory = self.args.create_factory()
+
         # prefetch model with factory
-        self._prefetch_model()
+        self._factory.prefetch_checkpoint()
 
         # NOTE (lucaslie): do regular build model, we bypass the regular LLM CachedModelLoader in
         # _autodeploy backend.
         super()._build_model()
 
+        # now correct input processor
+        assert isinstance(self.input_processor, DefaultInputProcessor)
+        assert isinstance(self.tokenizer, TransformersTokenizer)
+        self.input_processor = self._create_input_processor()
+
 
 class DemoLLM(LLM):
     """A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@@ -61,9 +131,10 @@ def __init__(self, **kwargs):
         self.runtime_context = None
 
         # prefetch model and load tokenizer
-        self._prefetch_model()
+        self._factory = self.args.create_factory()
+        self._factory.prefetch_checkpoint()
         self._tokenizer = self._try_load_tokenizer()
-        self.input_processor = create_input_processor(None, self.tokenizer)
+        self.input_processor = self._create_input_processor()
 
         # construct demo executor + engine
         self._executor = DemoGenerationExecutor(
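As a rough, stand-alone illustration of the two code paths in ADInputProcessor.__call__ above, the snippet below drives a plain HF tokenizer directly: a text prompt goes through tokenizer.encode, while a messages list goes through apply_chat_template with tokenization enabled. The checkpoint name is only an example, and nothing here touches the TRT-LLM classes from this commit.

from transformers import AutoTokenizer

# Example (non-gated) checkpoint; any HF model that ships a chat template works.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Path 1: plain text prompt -> straight tokenizer.encode, as in the else-branch above.
prompt_ids = tok.encode("Where is the capital of Iceland? ", add_special_tokens=True)

# Path 2: HF chat-template messages -> apply_chat_template with tokenize=True,
# mirroring the "messages" branch of ADInputProcessor.__call__.
messages = [{"role": "user", "content": "How to fix slicing in golf?"}]
chat_out = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
chat_ids = chat_out["input_ids"][0].tolist()

print(f"plain prompt -> {len(prompt_ids)} tokens, chat template -> {len(chat_ids)} tokens")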

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 9 additions & 0 deletions

@@ -113,6 +113,15 @@ def init_tokenizer(self) -> Optional[Any]:
         """
         return None
 
+    def init_processor(self) -> Optional[Any]:
+        """Initialize the (multi-modal) processor for the model.
+
+        Returns:
+            The initialized processor for the model. If the processor is not available, then this
+            method should return None.
+        """
+        return None
+
     def prefetch_checkpoint(self, force: bool = False):
         """Try or skip prefetching the checkpoint for the model and tokenizer.
 

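A concrete factory could fill in the new init_processor hook roughly as sketched below; the subclass, its constructor, and the use of AutoProcessor are illustrative assumptions, not part of this commit.

from typing import Any, Optional

from transformers import AutoProcessor, AutoTokenizer


class MyMultimodalFactory:  # hypothetical stand-in for a ModelFactory subclass
    """Illustration only: pairs init_tokenizer with the new init_processor hook."""

    def __init__(self, model: str):
        self.model = model  # e.g. an HF hub id or a local checkpoint path

    def init_tokenizer(self) -> Optional[Any]:
        return AutoTokenizer.from_pretrained(self.model)

    def init_processor(self) -> Optional[Any]:
        # For multimodal checkpoints the processor bundles the tokenizer with image
        # handling and exposes apply_chat_template, which ADInputProcessor relies on.
        return AutoProcessor.from_pretrained(self.model)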