
Commit 743fb0a

[AutoDeploy] _AutoDeployLlmArgs as primary config object (NVIDIA#4891)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 91e8d43 commit 743fb0a

22 files changed (+529 −205 lines)
examples/auto_deploy/.vscode/launch.json

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     "program": "build_and_run_ad.py",
     "args": [
       "--config",
-      "{\"batch_size\": 2, \"page_size\": 16, \"world_size\": 2, \"compile_backend\": \"torch-simple\", \"attn_backend\": \"FlashInfer\",\"model_factory\": \"AutoModelForCausalLM\", \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"benchmark\": false}",
+      "{\"batch_size\": 2, \"attn_page_size\": 16, \"world_size\": 2, \"compile_backend\": \"torch-simple\", \"attn_backend\": \"FlashInfer\",\"model_factory\": \"AutoModelForCausalLM\", \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"benchmark\": false}",
       "--model-kwargs",
       "{}",
       // "{\"num_hidden_layers\": 3}",

examples/auto_deploy/README.md

Lines changed: 14 additions & 26 deletions
@@ -151,7 +151,7 @@ In the below example:
 | `"mla_backend"` | Specifies implementation for multi-head latent attention |
 | `"max_seq_len"` | Maximum sequence length for inference/cache |
 | `"max_batch_size"` | Maximum dimension for statically allocated KV cache |
-| `"page_size"` | Page size for attention |
+| `"attn_page_size"` | Page size for attention |
 | `"benchmark"` | Indicates whether to run the built-in benchmark for token generation |

 For default values and additional configuration options, refer to the [simple_config.py](./simple_config.py) file.
@@ -236,37 +236,25 @@ Here is an example of how you can build an LLM object with AutoDeploy integration

 ```
 from tensorrt_llm import LLM
-from tensorrt_llm.builder import BuildConfig
-from tensorrt_llm._torch.auto_deploy.shim import AutoDeployConfig

-# 1. Set up the build configuration
-build_config = BuildConfig(
-    max_seq_len=<MAX_SEQ_LEN>,
-    max_batch_size=<MAX_BS>,
-)
-build_config.plugin_config.tokens_per_block = <PAGE_SIZE>
-# if using "TritonWithFlattenedInputs" as backend, <PAGE_SIZE> should equal to <MAX_SEQ_LEN>
-# Refer to examples/auto_deploy/simple_config.py (line 109) for details.
-
-# 2. Set up AutoDeploy configuration
-# AutoDeploy will use its own cache implementation
-model_kwargs = {"use_cache":False}

-ad_config = AutoDeployConfig(
+# Construct the LLM high-level interface object with autodeploy as backend
+llm = LLM(
+    model=<HF_MODEL_CARD_OR_DIR>,
+    backend="_autodeploy",
+    tensor_parallel_size=<NUM_WORLD_RANK>,
     use_cuda_graph=True, # set True if using "torch-opt" as compile backend
     torch_compile_enabled=True, # set True if using "torch-opt" as compile backend
-    model_kwargs=model_kwargs,
+    model_kwargs={"use_cache": False}, # AutoDeploy uses its own cache implementation
     attn_backend="TritonWithFlattenedInputs", # choose between "TritonWithFlattenedInputs" and "FlashInfer"
+    attn_page_size=64, # page size for attention (tokens_per_block, should be == max_seq_len for triton)
     skip_loading_weights=False,
-)
-
-# 3. Construct the LLM high-level interface object with autodeploy as backend
-llm = LLM(
-    model=<HF_MODEL_CARD_OR_DIR>,
-    backend="autodeploy",
-    build_config=build_config,
-    auto_deploy_config=ad_config,
-    tensor_parallel_size=<NUM_WORLD_RANK>,
+    model_factory="AutoModelForCausalLM", # choose appropriate model factory
+    mla_backend="MultiHeadLatentAttention", # for models that support MLA
+    free_mem_ratio=0.8, # fraction of available memory for cache
+    simple_shard_only=False, # tensor parallelism sharding strategy
+    max_seq_len=<MAX_SEQ_LEN>,
+    max_batch_size=<MAX_BATCH_SIZE>,
 )

 ```
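Putting the updated README snippet together, a complete call with concrete values might look like the sketch below. The keyword arguments mirror the diff above; the specific model card, backend choices, prompt, and the `generate`/`SamplingParams` usage at the end are illustrative assumptions rather than part of this commit.

```python
from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model card
    backend="_autodeploy",
    tensor_parallel_size=1,
    use_cuda_graph=False,               # enable with the "torch-opt" compile backend
    torch_compile_enabled=False,        # enable with the "torch-opt" compile backend
    model_kwargs={"use_cache": False},  # AutoDeploy uses its own cache implementation
    attn_backend="FlashInfer",
    attn_page_size=64,                  # tokens per attention page
    skip_loading_weights=False,
    model_factory="AutoModelForCausalLM",
    max_seq_len=512,
    max_batch_size=8,
)

# Illustrative generation call using the standard LLM API.
outputs = llm.generate(["What is AutoDeploy?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```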

examples/auto_deploy/build_and_run_ad.py

Lines changed: 16 additions & 26 deletions
@@ -8,10 +8,9 @@
 from simple_config import SimpleConfig

 from tensorrt_llm._torch.auto_deploy.models import ModelFactoryRegistry
-from tensorrt_llm._torch.auto_deploy.shim import AutoDeployConfig, DemoLLM
+from tensorrt_llm._torch.auto_deploy.shim import DemoLLM
 from tensorrt_llm._torch.auto_deploy.utils.benchmark import benchmark, store_benchmark_results
 from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
-from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi.llm import LLM, RequestOutput
 from tensorrt_llm.sampling_params import SamplingParams

@@ -33,27 +32,6 @@ def get_config_and_check_args() -> SimpleConfig:

 def build_llm_from_config(config: SimpleConfig) -> LLM:
     """Builds a LLM object from our config."""
-    # set up builder config
-    build_config = BuildConfig(max_seq_len=config.max_seq_len, max_batch_size=config.max_batch_size)
-    build_config.plugin_config.tokens_per_block = config.page_size
-
-    # setup AD config
-    ad_config = AutoDeployConfig(
-        # Both torch-opt and torch-cudagraph invoke cudagraphs
-        use_cuda_graph=config.compile_backend in ["torch-opt", "torch-cudagraph"],
-        # Both torch-opt and torch-compile invoke torch.compile
-        torch_compile_enabled=config.compile_backend in ["torch-opt", "torch-compile"],
-        model_factory=config.model_factory,
-        model_kwargs=config.model_kwargs,
-        attn_backend=config.attn_backend,
-        mla_backend=config.mla_backend,
-        skip_loading_weights=config.skip_loading_weights,
-        cuda_graph_max_batch_size=config.max_batch_size,
-        free_mem_ratio=config.free_mem_ratio,
-        simple_shard_only=config.simple_shard_only,
-    )
-    ad_logger.info(f"AutoDeploy Config: {ad_config}")
-
     # TODO: let's see if prefetching can't be done through the LLM api?
     # I believe the "classic workflow" invoked via the LLM api can do that.
     # put everything into the HF model Factory and try pre-fetching the checkpoint
@@ -73,9 +51,21 @@ def build_llm_from_config(config: SimpleConfig) -> LLM:
     }
     llm = llm_lookup[config.runtime](
         model=factory.model,
-        backend="autodeploy",
-        build_config=build_config,
-        auto_deploy_config=ad_config,
+        backend="_autodeploy",
+        max_seq_len=config.max_seq_len,
+        max_batch_size=config.max_batch_size,
+        # AutoDeploy-specific parameters
+        use_cuda_graph=config.compile_backend in ["torch-opt", "torch-cudagraph"],
+        torch_compile_enabled=config.compile_backend in ["torch-opt", "torch-compile"],
+        model_factory=config.model_factory,
+        model_kwargs=config.model_kwargs,
+        attn_backend=config.attn_backend,
+        mla_backend=config.mla_backend,
+        skip_loading_weights=config.skip_loading_weights,
+        cuda_graph_max_batch_size=config.max_batch_size,
+        free_mem_ratio=config.free_mem_ratio,
+        simple_shard_only=config.simple_shard_only,
+        attn_page_size=config.attn_page_size,  # Now passed directly as AutoDeploy parameter
         tensor_parallel_size=config.world_size,
         tokenizer=factory.init_tokenizer() if config.customize_tokenizer else None,
     )
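The mapping from `compile_backend` to the two boolean flags is compact but easy to misread, so here is a standalone sketch of the same membership checks taken from the diff above (the helper function name is made up for illustration):

```python
from typing import Dict


def compile_backend_flags(compile_backend: str) -> Dict[str, bool]:
    """Translate a compile backend choice into the two AutoDeploy flags.

    "torch-opt" turns on both CUDA graphs and torch.compile, "torch-cudagraph"
    only CUDA graphs, "torch-compile" only torch.compile, "torch-simple" neither.
    """
    return {
        "use_cuda_graph": compile_backend in ["torch-opt", "torch-cudagraph"],
        "torch_compile_enabled": compile_backend in ["torch-opt", "torch-compile"],
    }


assert compile_backend_flags("torch-opt") == {"use_cuda_graph": True, "torch_compile_enabled": True}
assert compile_backend_flags("torch-simple") == {"use_cuda_graph": False, "torch_compile_enabled": False}
```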

examples/auto_deploy/simple_config.py

Lines changed: 5 additions & 7 deletions
@@ -7,6 +7,7 @@
 from typing import Dict, List, Literal, Optional, Union


+# TODO: remove and unify with _AutoDeployLlmArgs
 @dataclass
 class SimpleConfig:
     """Experiment Configuration."""
@@ -55,7 +56,7 @@ class SimpleConfig:
     mla_backend: Literal["MultiHeadLatentAttention"] = "MultiHeadLatentAttention"
     max_seq_len: int = 512  # max sequence length for inference/cache
     max_batch_size: int = 8  # max dimension for statically allocated kv cache
-    page_size: int = 64  # page size for attention
+    attn_page_size: int = 64  # page size for attention
     simple_shard_only: bool = False  # if True, force simple sharding(all_gather) in TP;
     # otherwise auto-detect and use column+row (all_reduce) sharding

@@ -94,11 +95,8 @@ def __post_init__(self):
         # check if model was supplied
         assert self.model, "model must be supplied!"

-        # we don't want to loose the default values for model_kwargs unless explicitly set by the
-        # user. They are not preserved by the standard initialization process since they whole dict
-        # gets replaced by the user provided one. We don't want that though.
-        f_default = self.__dataclass_fields__["model_kwargs"].default_factory()
-        setattr(self, "model_kwargs", {**f_default, **getattr(self, "model_kwargs")})
+        # NEVER use cache
+        self.model_kwargs["use_cache"] = False

         # special handling for torch_dtype in model_kwargs since HF does not correctly update
         # torch_dtype string to an actual torch.dtype object (only with default)
@@ -120,7 +118,7 @@ def __post_init__(self):

         # No paging allowed in TritonWithFlattenedInputs
         if self.attn_backend in ["TritonWithFlattenedInputs"]:
-            self.page_size = self.max_seq_len
+            self.attn_page_size = self.max_seq_len

         # use min instead of max to avoid OOM for large batch size
         self.model_kwargs["max_position_embeddings"] = min(

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 2 additions & 2 deletions
@@ -86,7 +86,7 @@ class SequenceInfo:
     # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL).
     # Similarly, if a batch is composed of generate-only requests,
     # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
-    max_num_tokens: int = 0
+    max_num_tokens: Optional[int] = None

     ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
     # input_ids MUST ALWAYS BE THE FIRST FIELD
@@ -112,7 +112,7 @@ def __post_init__(self):
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
         max_seq_len_adjusted = self.max_seq_len + 1

-        if self.max_num_tokens < 1:
+        if self.max_num_tokens is None or self.max_num_tokens < 1:
             self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
         # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
         # we use the provided max_num_tokens to calculate the number of pages
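The net effect of switching `max_num_tokens` to `Optional[int]` is that both `None` and non-positive values now fall back to the `max_batch_size * (max_seq_len + 1)` default. A minimal sketch of just that resolution step (the helper name is invented for illustration; the logic mirrors the hunk above):

```python
from typing import Optional


def resolve_max_num_tokens(
    max_num_tokens: Optional[int], max_batch_size: int, max_seq_len: int
) -> int:
    """Mirror of the fallback in SequenceInfo.__post_init__ after this change."""
    # max_seq_len is padded by one token,
    # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504.
    max_seq_len_adjusted = max_seq_len + 1
    if max_num_tokens is None or max_num_tokens < 1:
        return max_batch_size * max_seq_len_adjusted
    return max_num_tokens


assert resolve_max_num_tokens(None, max_batch_size=8, max_seq_len=512) == 8 * 513
assert resolve_max_num_tokens(2048, max_batch_size=8, max_seq_len=512) == 2048
```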

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 21 additions & 13 deletions
@@ -75,9 +75,18 @@ def __init__(
         self.tokenizer_kwargs.setdefault("trust_remote_code", True)
         self._quant_config = None

-        # heuristic to disable use_cache
+        # NEVER use cache
         self.model_kwargs["use_cache"] = False

+        # special handling for torch_dtype in model_kwargs since HF does not correctly update
+        # torch_dtype string to an actual torch.dtype object (only with default)
+        if "torch_dtype" in self.model_kwargs:
+            dtype = self.model_kwargs["torch_dtype"]
+            if isinstance(dtype, str):
+                dtype = getattr(torch, self.model_kwargs["torch_dtype"])
+            assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
+            self.model_kwargs["torch_dtype"] = dtype
+
         # prefetch the model+checkpoint
         self.prefetch_checkpoint()
         # load the quantization config
@@ -322,19 +331,18 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-        # additional heuristic to disable use_cache
-        self.model_kwargs["text_config"] = self.model_kwargs.get("text_config", {})
-        self.model_kwargs["text_config"]["use_cache"] = False
-
-        self.model_kwargs["text_config"]["max_position_embeddings"] = self.model_kwargs[
-            "max_position_embeddings"
-        ]
-
-        # additional heuristic to propagate use of num_hidden_layers
+        # additional heuristic to propagate "important keys"
         # TODO (lucaslie): WAR until we have better support on dashboard to control model_kwargs
-        nhl_key = "num_hidden_layers"
-        if nhl_key in self.model_kwargs:
-            self.model_kwargs["text_config"][nhl_key] = self.model_kwargs[nhl_key]
+        keys_to_propagate = [
+            "num_hidden_layers",
+            "max_position_embeddings",
+            "use_cache",
+            "torch_dtype",
+        ]
+        self.model_kwargs["text_config"] = self.model_kwargs.get("text_config", {})
+        for key in keys_to_propagate:
+            if key in self.model_kwargs:
+                self.model_kwargs["text_config"][key] = self.model_kwargs[key]

     @property
     def automodel_from_config(self):
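Both hunks above manipulate `model_kwargs` in place. A condensed sketch of the two pieces of logic, the `torch_dtype` normalization and the propagation of "important keys" into the nested `text_config` (the standalone function names are illustrative, not part of the commit):

```python
from typing import Any, Dict

import torch


def normalize_torch_dtype(model_kwargs: Dict[str, Any]) -> None:
    """Convert a string torch_dtype (e.g. "bfloat16") into a torch.dtype object."""
    if "torch_dtype" in model_kwargs:
        dtype = model_kwargs["torch_dtype"]
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
        model_kwargs["torch_dtype"] = dtype


def propagate_to_text_config(model_kwargs: Dict[str, Any]) -> None:
    """Copy the "important keys" from the top level into the nested text_config."""
    keys_to_propagate = ["num_hidden_layers", "max_position_embeddings", "use_cache", "torch_dtype"]
    model_kwargs["text_config"] = model_kwargs.get("text_config", {})
    for key in keys_to_propagate:
        if key in model_kwargs:
            model_kwargs["text_config"][key] = model_kwargs[key]


kwargs = {"torch_dtype": "bfloat16", "use_cache": False, "num_hidden_layers": 3}
normalize_torch_dtype(kwargs)
propagate_to_text_config(kwargs)
assert kwargs["torch_dtype"] is torch.bfloat16
assert kwargs["text_config"]["num_hidden_layers"] == 3
```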

tensorrt_llm/_torch/auto_deploy/shim/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@

 from .ad_executor import create_autodeploy_executor
 from .demollm import DemoLLM
-from .interface import AutoDeployConfig, CachedSequenceInterface, GetInferenceModel
+from .interface import CachedSequenceInterface, GetInferenceModel

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 22 additions & 21 deletions
@@ -9,6 +9,7 @@
 from ...._utils import mpi_rank, mpi_world_size
 from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType
+from ....llmapi.llm_args import _AutoDeployLlmArgs
 from ....mapping import Mapping
 from ...distributed import MPIDist
 from ...pyexecutor.config import PyTorchConfig
@@ -27,7 +28,7 @@
 from ..models import ModelFactoryRegistry
 from ..transformations.transform import InferenceOptimizer
 from ..utils.logger import ad_logger
-from .interface import AutoDeployConfig, CachedSequenceInterface, GetInferenceModel
+from .interface import CachedSequenceInterface, GetInferenceModel


 class _CacheManagerWithFakePool(KVCacheManager):
@@ -84,11 +85,11 @@ def _device(self) -> DeviceLikeType:
     def build_from_config(
         cls,
         model: str,
-        ad_config: AutoDeployConfig,
+        ad_config: _AutoDeployLlmArgs,
         seq_info: SequenceInfo,
         device: DeviceLikeType,
     ):
-        """Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
+        """Build the ADEngine using the _AutoDeployLlmArgs that gets passed through from the LLM."""

         # update device to contain the current default device if it's in cuda
         device = torch.device(device)
@@ -245,7 +246,7 @@ def create_autodeploy_executor(
 ):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.

-    This is the entrypoint API to the autodeploy backend.
+    This is the entrypoint API to the _autodeploy backend.
     """
     # initialize process groups
     world_size = mpi_world_size()
@@ -258,33 +259,34 @@
     dist.initialize_or_skip(rank, world_size, port)

     # some config
-    if executor_config.pytorch_backend_config is None:
-        executor_config.pytorch_backend_config = AutoDeployConfig(attn_backend="FlashInfer")
+    msg = "pytorch_backend_config must be an _AutoDeployLlmArgs object"
+    assert isinstance(executor_config.pytorch_backend_config, _AutoDeployLlmArgs), msg
+    ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config

-    max_batch_size = executor_config.max_batch_size
-    max_seq_len = executor_config.max_seq_len
-    tokens_per_block = executor_config.tokens_per_block
-    max_num_tokens = executor_config.max_num_tokens
-    ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {tokens_per_block=}, {max_num_tokens=}")
+    max_batch_size = ad_config.max_batch_size
+    max_seq_len = ad_config.max_seq_len
+    attn_page_size = ad_config.attn_page_size
+    max_num_tokens = ad_config.max_num_tokens
+    ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}")

     # initialize model engine
     engine = ADEngine.build_from_config(
         model=checkpoint_dir,
-        ad_config=executor_config.pytorch_backend_config,
+        ad_config=ad_config,
         seq_info=SequenceInfo(
             max_seq_len=max_seq_len,
             max_batch_size=max_batch_size,
-            page_size=tokens_per_block,
+            page_size=attn_page_size,
             max_num_tokens=max_num_tokens,
         ),
         device="cuda",
     )

     # resource managers
     kv_cache_manager = _CacheManagerWithFakePool(
-        executor_config.kv_cache_config,
+        ad_config.kv_cache_config,
         num_blocks=engine.cache_seq_interface.info.num_pages,
-        tokens_per_block=tokens_per_block,
+        tokens_per_block=attn_page_size,
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
     )
@@ -302,18 +304,17 @@
     sampler = TorchSampler(max_seq_len=max_seq_len)

     # creating the executor object
-    py_config: PyTorchConfig = executor_config.pytorch_backend_config
     py_executor = PyExecutor(
         resource_manager,
         scheduler,
         model_engine=engine,
         sampler=sampler,
         dist=mpi_dist,
-        disable_overlap_scheduler=py_config.disable_overlap_scheduler,
-        max_input_len=executor_config.max_input_len,
-        max_batch_size=executor_config.max_batch_size,
-        max_draft_tokens=executor_config.speculative_config.max_draft_tokens
-        if executor_config.speculative_config is not None
+        disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        max_input_len=ad_config.max_input_len,
+        max_batch_size=ad_config.max_batch_size,
+        max_draft_tokens=ad_config.speculative_config.max_draft_tokens
+        if ad_config.speculative_config is not None
         else 0,
     )
     return py_executor
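After this change the executor no longer falls back to a default config or reads sizes from `executor_config`; everything comes from the single `_AutoDeployLlmArgs` object carried in `pytorch_backend_config`. A schematic of that retrieval pattern, where both dataclasses and the field subset are illustrative stand-ins rather than the real `_AutoDeployLlmArgs` or `ExecutorConfig`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _FakeAutoDeployLlmArgs:
    """Illustrative stand-in exposing the fields read by create_autodeploy_executor."""

    max_batch_size: int = 8
    max_seq_len: int = 512
    attn_page_size: int = 64
    max_num_tokens: Optional[int] = None


@dataclass
class _FakeExecutorConfig:
    pytorch_backend_config: object = None


def get_ad_config(executor_config: _FakeExecutorConfig) -> _FakeAutoDeployLlmArgs:
    # The old code silently substituted a default AutoDeployConfig; the new code
    # insists on receiving the fully populated args object.
    msg = "pytorch_backend_config must be an _AutoDeployLlmArgs object"
    assert isinstance(executor_config.pytorch_backend_config, _FakeAutoDeployLlmArgs), msg
    return executor_config.pytorch_backend_config


ad_config = get_ad_config(_FakeExecutorConfig(_FakeAutoDeployLlmArgs()))
print(ad_config.max_seq_len, ad_config.attn_page_size)
```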
