
Commit fb3a82b

sequence interface revisited
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 913695f

10 files changed, +456 -186 lines changed

examples/auto_deploy/.gitignore

Lines changed: 2 additions & 0 deletions
@@ -2,3 +2,5 @@
 !.vscode
 benchmark_results.json
 *.png
+# ignore config files that users might put here for debugging
+*.yaml

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 281 additions & 130 deletions
Large diffs are not rendered by default.
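
The attention_interface.py diff itself is not rendered, but the call sites touched elsewhere in this commit (add_extra_arg, the keyword-argument form of nest_sequences, named_standard_args, extra_args_for_prepare_metadata, set_example_sequence, set_max_num_tokens_sample) suggest roughly the following SequenceInfo surface. This is only an inferred sketch; signatures, defaults, and the DynamicShapeCallback alias below are assumptions, not the actual diff.

# Inferred sketch of the revised SequenceInfo interface; reconstructed from call sites in this
# commit, not from the rendered diff. Treat all names and signatures as assumptions.
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

DynamicShapeCallback = Callable[[], Dict[int, Any]]  # assumed alias for a dynamic-shape spec


class SequenceInfo:
    def add_extra_arg(
        self, name: str, none_input: torch.Tensor, dynamic_shape_callback: DynamicShapeCallback
    ) -> None:
        """Register a factory-defined extra input (e.g., multimodal tensors)."""

    def nest_sequences(
        self,
        input_ids: List[torch.Tensor],
        input_pos: Optional[List[int]] = None,
        page_assignments: Optional[List[List[int]]] = None,
        **extra_args: List[torch.Tensor],
    ) -> None:
        """Single update entry point replacing separate update_pos()/assign_cache_loc() calls."""

    def set_example_sequence(self, **example_inputs: torch.Tensor) -> None:
        """Set a dummy sequence (plus factory example inputs) used for export."""

    def set_max_num_tokens_sample(self) -> None:
        """Set a maximum-size dummy sequence (now public) used to probe memory usage."""

    @property
    def named_standard_args(self) -> Dict[str, Any]:
        """Standard graph inputs (input_ids, position_ids, ...) keyed by argument name."""

    @property
    def extra_args_for_prepare_metadata(self) -> Tuple[Any, ...]:
        """Extra positional args (e.g., page_size) appended to the prepare_metadata op."""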

tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 82 additions & 11 deletions
@@ -1,19 +1,83 @@
 import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from ...executor.result import CompletionOutput
-from ...inputs.registry import create_input_processor
+from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
 from ...llmapi.llm import RequestOutput, _TorchLLM
-from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
+from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
+from ...sampling_params import SamplingParams
 from .distributed import common as dist_ad
 from .llm_args import LlmArgs
+from .models.factory import ModelFactory
 from .shim.demollm import DemoGenerationExecutor
 
 
+class ADInputProcessor(DefaultInputProcessor):
+    """Input processor for AutoDeploy backend.
+
+    This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
+    message chat template system to process multimodal inputs.
+    """
+
+    def __init__(self, tokenizer: TokenizerBase, processor: Optional[Any] = None):
+        super().__init__(None, None, tokenizer)
+        # NOTE: HF's tokenizer/processor that has the apply_chat_template method
+        self.processor = processor or tokenizer.tokenizer
+
+    def __call__(
+        self, inputs: Dict[str, Any], sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        # construct kwargs to reflect DefaultInputProcessor
+        kwargs = {
+            "add_special_tokens": sampling_params.add_special_tokens,
+        }
+        if sampling_params.truncate_prompt_tokens is not None:
+            kwargs = {
+                "truncation": True,
+                "max_length": sampling_params.truncate_prompt_tokens,
+            }
+        # check for messages field and if yes, use the apply_chat_template method
+        if "messages" in inputs:
+            # TODO: we don't really need this but it makes for a good sanity check. Consider
+            # removing this in the future if we need to speed things up.
+            prompt = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            inputs["prompt"] = prompt
+
+            all_args = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=False,  # there shouldn't be a need for padding ever...
+                return_attention_mask=False,
+                **kwargs,
+            )
+            # TODO: is there a more reliable way to avoid the attention_mask here?
+            all_args.pop("attention_mask", None)
+
+            # TODO: can we avoid the extra tolist() here eventually?
+            token_ids = all_args.pop("input_ids")
+            assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
+            return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None
+        else:
+            token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
+            return token_ids, None
+
+
 class LLM(_TorchLLM):
     """LLM class is the main class for running an LLM model using AutoDeploy backend."""
 
     args: LlmArgs
+    _factory: ModelFactory
+
+    @property
+    def factory(self) -> ModelFactory:
+        return self._factory
 
     def __init__(self, *args, **kwargs):
         kwargs["backend"] = "_autodeploy"
@@ -23,30 +87,36 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
         if self.args.skip_tokenizer_init:
             return None
 
-        factory = self.args.create_factory()
-        return tokenizer_factory(factory.init_tokenizer())
+        return tokenizer_factory(self._factory.init_tokenizer())
 
     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """We don't need to validate args for AutoDeploy backend for now."""
         pass
 
-    def _prefetch_model(self):
-        """Prefetch the model for the LLM."""
-        self.args.create_factory().prefetch_checkpoint()
+    def _create_input_processor(self) -> ADInputProcessor:
+        return ADInputProcessor(self.tokenizer, self._factory.init_processor())
 
     def _build_model(self):
         """Build the model for the LLM.
 
         This is a wrapper around the regular build model method that prefetches the model with the
         factory.
         """
+        # create and store a factory
+        self._factory = self.args.create_factory()
+
         # prefetch model with factory
-        self._prefetch_model()
+        self._factory.prefetch_checkpoint()
 
         # NOTE (lucaslie): do regular build model, we bypass the regular LLM CachedModelLoader in
         # _autodeploy backend.
         super()._build_model()
 
+        # now correct input processor
+        assert isinstance(self.input_processor, DefaultInputProcessor)
+        assert isinstance(self.tokenizer, TransformersTokenizer)
+        self.input_processor = self._create_input_processor()
+
 
 class DemoLLM(LLM):
     """A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@@ -61,9 +131,10 @@ def __init__(self, **kwargs):
         self.runtime_context = None
 
         # prefetch model and load tokenizer
-        self._prefetch_model()
+        self._factory = self.args.create_factory()
+        self._factory.prefetch_checkpoint()
         self._tokenizer = self._try_load_tokenizer()
-        self.input_processor = create_input_processor(None, self.tokenizer)
+        self.input_processor = self._create_input_processor()
 
         # construct demo executor + engine
         self._executor = DemoGenerationExecutor(
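
For context, here is a minimal usage sketch of the new ADInputProcessor (not part of the diff). The model name, the SamplingParams fields, and the assumption that tokenizer_factory accepts a model name or path are all illustrative.

# Hypothetical usage of ADInputProcessor; model name, SamplingParams fields, and the
# tokenizer_factory call are assumptions for illustration only.
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.sampling_params import SamplingParams

tokenizer = tokenizer_factory("meta-llama/Llama-3.1-8B-Instruct")
processor = ADInputProcessor(tokenizer)  # falls back to tokenizer.tokenizer for chat templates
params = SamplingParams(max_tokens=32)

# text-only path: routed through tokenizer.encode(), extra inputs are None
token_ids, extra = processor({"prompt": "Hello, world!"}, params)
assert extra is None

# chat path: routed through apply_chat_template(); extra may carry "multimodal_data"
token_ids, extra = processor({"messages": [{"role": "user", "content": "Hello!"}]}, params)

The else branch keeps the existing TRT-LLM text-only behavior, so plain-prompt requests are unaffected by this change.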

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 38 additions & 2 deletions
@@ -2,13 +2,13 @@
 
 import copy
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
 from torch._prims_common import DeviceLikeType
 
-from ..custom_ops.attention_interface import CacheConfig
+from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback
 from ..utils.logger import ad_logger
 
 
@@ -113,6 +113,15 @@ def init_tokenizer(self) -> Optional[Any]:
         """
         return None
 
+    def init_processor(self) -> Optional[Any]:
+        """Initialize the (multi-modal) processor for the model.
+
+        Returns:
+            The initialized processor for the model. If the processor is not available, then this
+            method should return None.
+        """
+        return None
+
     def prefetch_checkpoint(self, force: bool = False):
         """Try or skip prefetching the checkpoint for the model and tokenizer.
 
@@ -206,6 +215,33 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
             device: The device to load the model on.
         """
 
+    def get_example_inputs(self) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of example inputs for the model.
+
+        This function can be overwritten by a factory when it requires a specific example input
+        in order to run through export.
+
+        Returns:
+            A dictionary of example inputs for the model where the key corresponds to the argument
+            name and the value corresponds to the example input.
+        """
+        return {}
+
+    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
+        """Return a dictionary of extra inputs for the model.
+
+        Returns:
+            A dictionary of extra inputs for the model where the key corresponds to the argument
+            name and the value corresponds to a tuple of (none_input, dynamic_shape_callback):
+                - `none_input`: The none input value of the extra input indicating the tensor
+                  value corresponding to the equivalent of the None input. `None` is not supported
+                  as we require the input to be a tensor. Hence, this none_input acts as a
+                  placeholder for the None input.
+                - `dynamic_shape_callback`: A function that returns the dynamic shape of the extra
+                  input.
+        """
+        return {}
+
 
 class ModelFactoryRegistry:
     _registry: Dict[str, Type[ModelFactory]] = {}
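
To illustrate the two new hooks, here is a hypothetical factory override (not from this commit). The "pixel_values" name, the tensor shapes, and the assumed return format of the dynamic-shape callback are placeholders.

# Hypothetical ModelFactory subclass showing get_example_inputs/get_extra_inputs; names,
# shapes, and the dynamic-shape callback contract are assumptions for illustration only.
import torch
from torch.export import Dim


class MyMultimodalFactory(ModelFactory):  # other abstract methods omitted for brevity
    def get_example_inputs(self) -> Dict[str, torch.Tensor]:
        # consumed by the export transform via cm.info.set_example_sequence(**...)
        return {"pixel_values": torch.rand(1, 3, 336, 336)}

    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
        # none_input: tensor placeholder standing in for "no image provided" (None not allowed)
        none_pixel_values = torch.zeros(0, 3, 336, 336)

        def _dynamic_shape():
            # assumed format: dynamic dim spec for the leading (image-count) dimension
            return {0: Dim("num_images", max=8)}

        return {"pixel_values": (none_pixel_values, _dynamic_shape)}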

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 29 additions & 11 deletions
@@ -1,6 +1,6 @@
-from itertools import chain
+from collections import defaultdict
 from types import SimpleNamespace
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch._prims_common import DeviceLikeType
@@ -102,16 +102,24 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
             max_num_tokens=max_num_tokens,
         )
 
+        # get factory
+        factory = ad_config.create_factory()
+
         # update device to contain the current default device if it's in cuda
         device = torch.device(ad_config.device)
         if device.type == "cuda" and device.index is None:
            device = torch.device(f"cuda:{torch.cuda.current_device()}")
         device = str(device)
 
+        # pass in extra arguments defined by the model factory
+        for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items():
+            seq_info.add_extra_arg(name, none_input, dynamic_shape_callback)
+
+        # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
+        # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
+
         # construct inference optimizer
-        build_and_optimize = InferenceOptimizer(
-            factory=ad_config.create_factory(), ad_config=ad_config
-        )
+        build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)
 
         # construct engine
         return cls(build_and_optimize, seq_info, device, max_beam_width)
@@ -176,6 +184,7 @@ def _prepare_inputs(
         input_pos: List[int] = []
         last_logit_only: List[bool] = []
         page_assignments: List[List[int]] = []
+        extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)
 
         # look at context requests first
         for request in context_requests:
@@ -186,6 +195,15 @@
             request.py_batch_idx = request.seq_slot
             last_logit_only.append(True)
 
+            # get cache indices
+            cache_indices = kv_cache_manager.get_cache_indices(request)
+            page_assignments.append(cache_indices)
+
+            # store extra arguments
+            if request.py_multimodal_data is not None:
+                for k, v in request.py_multimodal_data.items():
+                    extra_args[k].append(v)
+
         # look at generate requests next
         # TODO: we should also handle extend requests (for speculative decoding) here
         for request in gen_requests:
@@ -202,17 +220,17 @@
             # return all logits
             last_logit_only.append(False)
 
-        # extract cache information for all requests
-        for request in chain(context_requests, gen_requests):
             # get cache indices
             cache_indices = kv_cache_manager.get_cache_indices(request)
             page_assignments.append(cache_indices)
 
         # update the sequence info object now
-        si = self.cache_seq_interface.info
-        si.nest_sequences(input_ids)
-        si.update_pos(input_pos, reset=True)
-        si.assign_cache_loc(page_assignments)
+        self.cache_seq_interface.info.nest_sequences(
+            input_ids,
+            input_pos=input_pos,
+            page_assignments=page_assignments,
+            **extra_args,
+        )
        return last_logit_only
 
    def _compute_logits(self) -> List[torch.Tensor]:
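
The per-request handling above boils down to grouping each request's multimodal tensors by argument name before splatting them into nest_sequences(). A self-contained illustration (request objects replaced by plain dicts) is:

# Standalone illustration of the extra-args grouping; real code reads request.py_multimodal_data
# from LlmRequest objects, which are replaced here by plain dicts/None.
from collections import defaultdict

import torch

per_request_multimodal_data = [
    {"pixel_values": torch.rand(1, 3, 336, 336)},  # request carrying an image
    None,                                           # text-only request contributes nothing
]

extra_args = defaultdict(list)
for mm_data in per_request_multimodal_data:
    if mm_data is not None:
        for k, v in mm_data.items():
            extra_args[k].append(v)

# afterwards: info.nest_sequences(input_ids, input_pos=..., page_assignments=..., **extra_args)
print({k: len(v) for k, v in extra_args.items()})  # {'pixel_values': 1}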

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def _apply(
         model = gm.get_submodule("factory_model")
 
         # set the example sequence
-        cm.info.set_example_sequence()
+        cm.info.set_example_sequence(**factory.get_example_inputs())
 
         # export the model to a graph module
         gm = torch_export_to_gm(

tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py

Lines changed: 6 additions & 8 deletions
@@ -26,9 +26,6 @@ def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None:
     # loop through nodes to get input, output, and get_attr nodes
     input_nodes, output_nodes = get_all_input_output_nodes(egm.graph)
 
-    # we only expect one input node
-    assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
-
     # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
     # Later on, we can revisit how to support returning hidden states.
     assert len(output_nodes) == 1, "Expected exactly one output node!"
@@ -73,16 +70,17 @@
 
     # retrieve input nodes
     input_nodes, _ = get_all_input_output_nodes(egm.graph)
+    input_nodes_mapping = {n.target: n for n in input_nodes}
+
+    # filtered and sorted for SequenceInfo arguments (input_ids, position_ids, etc.)
+    input_nodes_from_info = [input_nodes_mapping[k] for k in cm.info.named_standard_args.keys()]
 
     # insert metadata computation and extract each argument as a node
     get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
     with graph.inserting_before(input_nodes[-1].next):
         ret_node = graph.call_function(
             get_metadata,
-            args=(
-                *input_nodes,
-                cm.info.page_size,
-            ),
+            args=(*input_nodes_from_info, *cm.info.extra_args_for_prepare_metadata),
         )
         metadata_nodes = [
             graph.call_function(operator.getitem, args=(ret_node, idx))
@@ -162,7 +160,7 @@ def _get_mem_info_in_mb():
 
     try:
         # Let's run a forward pass to get the memory usage
-        cm.info._set_max_num_tokens_sample()
+        cm.info.set_max_num_tokens_sample()
         free_mem_pre, _ = _get_mem_info_in_mb()
         ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
 
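The target-name lookup used above can be reproduced on any FX graph. A small self-contained example (toy module, not the AutoDeploy code) follows:

# Standalone FX example of mapping placeholder nodes by target name and selecting them in a
# required order, mirroring what insert_cached_attention now does with SequenceInfo arguments.
import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, input_ids, position_ids, pixel_values):
        return input_ids + position_ids + pixel_values.sum()


gm = symbolic_trace(Toy())
input_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
input_nodes_mapping = {n.target: n for n in input_nodes}

# select placeholders in a fixed order (illustrative stand-in for cm.info.named_standard_args)
ordered = [input_nodes_mapping[k] for k in ("input_ids", "position_ids")]
print([n.target for n in ordered])  # ['input_ids', 'position_ids']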
tensorrt_llm/executor/worker.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def __init__(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
         self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+                                           None) in ["pytorch", "_autodeploy"]
 
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py

Lines changed: 0 additions & 2 deletions
@@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i
     input_ids = [torch.tensor([0, 1, 2], device=device)]
     sequence_info.reset()
     sequence_info.nest_sequences(input_ids)
-    engine.cache_seq_interface.info.sync(sequence_info)
     logits = engine._compute_logits()
     logits = torch.stack(logits)
     assert logits is not None, "Logits are None"
@@ -106,7 +105,6 @@ def test_demo_engine_sampling(attn_page_size: int):
     input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
     sequence_info.reset()
     sequence_info.nest_sequences(input_ids)
-    engine.cache_seq_interface.info.sync(sequence_info)
     logits = engine._compute_logits()
     logits = torch.stack(logits)
