
Commit 89e0117

[TRTLLM-8836][chore] Create ModelEngine from LlmArgs (#8600)
Signed-off-by: junq <[email protected]>
1 parent d798d66 commit 89e0117

File tree

9 files changed: +160 -181 lines changed


tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 10 additions & 12 deletions
@@ -143,18 +143,16 @@ def __init__(
         """Initialize the engine with model and sequence information."""
         # NOTE (lucaslie): create a fake Namespace to satisfy PyExecutor requirements...
         # This is not correctly declared in the base ModelEngine class though...
-        self.pytorch_backend_config = SimpleNamespace()
-        self.pytorch_backend_config.print_iter_log = False
-        self.pytorch_backend_config.enable_iter_perf_stats = False
-        self.pytorch_backend_config.enable_iter_req_stats = False
-        self.pytorch_backend_config.stream_interval = 1
-        self.pytorch_backend_config.attention_dp_enable_balance = False
-        self.pytorch_backend_config.attention_dp_time_out_iters = 50
-        self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
-        self.pytorch_backend_config.batch_wait_timeout_ms = 0
-        self.pytorch_backend_config.batch_wait_timeout_iters = 0
-        self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
-        self.pytorch_backend_config.max_num_tokens = seq_info.max_num_tokens
+        self.llm_args = SimpleNamespace()
+        self.llm_args.print_iter_log = False
+        self.llm_args.enable_iter_perf_stats = False
+        self.llm_args.enable_iter_req_stats = False
+        self.llm_args.stream_interval = 1
+        self.llm_args.attention_dp_config = None
+        self.llm_args.batch_wait_timeout_ms = 0
+        self.llm_args.batch_wait_timeout_iters = 0
+        self.llm_args.batch_wait_max_tokens_ratio = 0.0
+        self.llm_args.max_num_tokens = seq_info.max_num_tokens
         self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
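
The shim does not build a real TorchLlmArgs; it fakes only the attributes PyExecutor actually reads. A minimal sketch of that stand-in as a standalone helper (the helper name is hypothetical; the attribute set mirrors the hunk above, not the full TorchLlmArgs schema):

    from types import SimpleNamespace

    def make_stub_llm_args(max_num_tokens: int) -> SimpleNamespace:
        """Hypothetical helper bundling the PyExecutor-facing defaults set inline above."""
        args = SimpleNamespace()
        args.print_iter_log = False
        args.enable_iter_perf_stats = False
        args.enable_iter_req_stats = False
        args.stream_interval = 1
        args.attention_dp_config = None  # single config object replaces the old attention_dp_* flags
        args.batch_wait_timeout_ms = 0
        args.batch_wait_timeout_iters = 0
        args.batch_wait_max_tokens_ratio = 0.0
        args.max_num_tokens = max_num_tokens
        return args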

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 2 additions & 3 deletions
@@ -36,9 +36,8 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.engine_ref = weakref.ref(engine)

         # High-level configuration
-        config = engine.pytorch_backend_config
-        self.enabled = config.use_cuda_graph
-        self.padding_enabled = config.cuda_graph_padding_enabled
+        self.enabled = engine.llm_args.cuda_graph_config is not None
+        self.padding_enabled = engine._cuda_graph_padding_enabled
         self.supported_batch_sizes = engine._cuda_graph_batch_sizes
         self.max_supported_batch_size = engine._max_cuda_graph_batch_size
         self.max_beam_width = engine.max_beam_width
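
With the flat PyTorchConfig gone, CUDA-graph enablement is inferred from the presence of the optional sub-config rather than from a dedicated use_cuda_graph flag. A minimal sketch of that presence-as-switch pattern (llm_args here is any object carrying a cuda_graph_config attribute; padding still comes from the engine-derived _cuda_graph_padding_enabled):

    # Sketch: the optional sub-config object doubles as the on/off switch.
    cuda_graph_config = getattr(llm_args, "cuda_graph_config", None)
    use_cuda_graph = cuda_graph_config is not None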

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 47 additions & 33 deletions
@@ -20,7 +20,8 @@
                                             MultimodalRuntimeData)
 from tensorrt_llm.inputs.registry import (create_input_processor,
                                           create_input_processor_with_hash)
-from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
+from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, TorchCompileConfig,
+                                          TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
@@ -53,7 +54,7 @@
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
                      set_torch_compiling, with_model_extra_attrs)
-from .config import PyTorchConfig, _construct_checkpoint_loader
+from .config import _construct_checkpoint_loader
 from .config_utils import is_mla
 from .cuda_graph_runner import CUDAGraphRunner
 from .guided_decoder import CapturableGuidedDecoder
@@ -131,7 +132,7 @@ def __init__(
         self,
         *,
         model_path: str,
-        pytorch_backend_config: PyTorchConfig,
+        llm_args: TorchLlmArgs,
         mapping: Optional[Mapping] = None,
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
@@ -140,10 +141,7 @@ def __init__(
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
         model: Optional[torch.nn.Module] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
     ):
-        assert llm_args is not None, "llm_args must be provided for PyTorchModelEngine"
-
         self.forward_pass_callable = None
         self.ub_buffers = None
         (
@@ -168,7 +166,7 @@ def __init__(
         self.dist = dist
         if dist is not None:
             ExpertStatistic.create(self.dist.rank)
-        self.pytorch_backend_config = pytorch_backend_config
+        self.llm_args = llm_args
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
         self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

@@ -192,7 +190,7 @@ def __init__(
         lora_config: Optional[
             LoraConfig] = None if is_draft_model else llm_args.lora_config
         loader = ModelLoader(
-            pytorch_backend_config=pytorch_backend_config,
+            llm_args=llm_args,
             mapping=self.mapping,
             spec_config=self.spec_config,
             sparse_attention_config=self.sparse_attention_config,
@@ -215,7 +213,7 @@ def __init__(
         # In case that some tests use stub models and override `_load_model`.
         if not hasattr(self.model, 'extra_attrs'):
             self.model.extra_attrs = {}
-        if self.pytorch_backend_config.enable_layerwise_nvtx_marker:
+        if self.llm_args.enable_layerwise_nvtx_marker:
             layerwise_nvtx_marker = LayerwiseNvtxMarker()
             module_prefix = 'Model'
             if self.model.model_config and self.model.model_config.pretrained_config and self.model.model_config.pretrained_config.architectures:
@@ -224,19 +222,39 @@ def __init__(
             layerwise_nvtx_marker.register_hooks(self.model, module_prefix)

         self.enable_attention_dp = self.model.model_config.mapping.enable_attention_dp
-        self._disable_overlap_scheduler = self.pytorch_backend_config.disable_overlap_scheduler
+        self._disable_overlap_scheduler = self.llm_args.disable_overlap_scheduler
         self._torch_compile_backend = None
         self.dtype = self.model.config.torch_dtype
         self._init_model_capacity()

-        self._torch_compile_backend = None
+        self.cuda_graph_config = self.llm_args.cuda_graph_config
+        cuda_graph_batch_sizes = self.cuda_graph_config.batch_sizes if self.cuda_graph_config else CudaGraphConfig.model_fields[
+            'batch_sizes'].default
+        cuda_graph_padding_enabled = self.cuda_graph_config.enable_padding if self.cuda_graph_config else CudaGraphConfig.model_fields[
+            'enable_padding'].default
+
+        self.torch_compile_config = self.llm_args.torch_compile_config
+        torch_compile_enabled = bool(self.torch_compile_config is not None)
+        torch_compile_fullgraph = self.torch_compile_config.enable_fullgraph if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_fullgraph'].default
+        torch_compile_inductor_enabled = self.torch_compile_config.enable_inductor if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_inductor'].default
+        torch_compile_piecewise_cuda_graph = self.torch_compile_config.enable_piecewise_cuda_graph if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_piecewise_cuda_graph'].default
+        torch_compile_piecewise_cuda_graph_num_tokens = self.torch_compile_config.capture_num_tokens if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'capture_num_tokens'].default
+        torch_compile_enable_userbuffers = self.torch_compile_config.enable_userbuffers if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_userbuffers'].default
+        torch_compile_max_num_streams = self.torch_compile_config.max_num_streams if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'max_num_streams'].default
+
         # Eagle3 draft model now does not support torch.compile
-        self._torch_compile_enabled = pytorch_backend_config.torch_compile_enabled and not is_draft_model
-        self._torch_compile_piecewise_cuda_graph = pytorch_backend_config.torch_compile_piecewise_cuda_graph
+        self._torch_compile_enabled = torch_compile_enabled
+        self._torch_compile_piecewise_cuda_graph = torch_compile_piecewise_cuda_graph

         piecewise_cuda_graph_num_tokens = (
-            pytorch_backend_config.torch_compile_piecewise_cuda_graph_num_tokens
-            or pytorch_backend_config.cuda_graph_batch_sizes or [])
+            torch_compile_piecewise_cuda_graph_num_tokens
+            or cuda_graph_batch_sizes or [])

         self._piecewise_cuda_graph_num_tokens = [
             i for i in piecewise_cuda_graph_num_tokens
@@ -245,33 +263,30 @@ def __init__(

         try:
             use_ub_for_nccl = (
-                pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+                self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
                 and self._init_userbuffers(self.model.config.hidden_size))
             if self._torch_compile_enabled:
                 set_torch_compiling(True)
                 use_ub = not use_ub_for_nccl and (
-                    pytorch_backend_config.torch_compile_enable_userbuffers
+                    torch_compile_enable_userbuffers
                     and self._init_userbuffers(self.model.config.hidden_size))
                 self._torch_compile_backend = Backend(
-                    pytorch_backend_config.torch_compile_inductor_enabled,
+                    torch_compile_inductor_enabled,
                     enable_userbuffers=use_ub,
                     enable_piecewise_cuda_graph=self.
                     _torch_compile_piecewise_cuda_graph,
                     capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
-                    max_num_streams=pytorch_backend_config.
-                    torch_compile_max_num_streams)
+                    max_num_streams=torch_compile_max_num_streams)
                 if isinstance(self.model, DecoderModelForCausalLM):
                     self.model.model = torch.compile(
                         self.model.model,
                         backend=self._torch_compile_backend,
-                        fullgraph=pytorch_backend_config.torch_compile_fullgraph
-                    )
+                        fullgraph=torch_compile_fullgraph)
                 else:
                     self.model = torch.compile(
                         self.model,
                         backend=self._torch_compile_backend,
-                        fullgraph=pytorch_backend_config.torch_compile_fullgraph
-                    )
+                        fullgraph=torch_compile_fullgraph)
                 torch._dynamo.config.cache_size_limit = 16
             else:
                 set_torch_compiling(False)
@@ -283,7 +298,7 @@ def __init__(
         self.is_warmup = False

         self.attn_backend = get_attention_backend(
-            pytorch_backend_config.attn_backend,
+            self.llm_args.attn_backend,
             sparse_attn_config=self.sparse_attention_config)

         if self.is_spec_decode:
@@ -329,13 +344,12 @@ def __init__(
         self.iter_states = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None

-        self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
+        self._cuda_graph_padding_enabled = cuda_graph_padding_enabled

         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
-            pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_total_draft_tokens,
-            self._cuda_graph_padding_enabled
-        ) if pytorch_backend_config.cuda_graph_batch_sizes else []
+            cuda_graph_batch_sizes, self.batch_size, self.max_num_tokens,
+            self.original_max_total_draft_tokens,
+            self._cuda_graph_padding_enabled) if cuda_graph_batch_sizes else []

         self._max_cuda_graph_batch_size = (self._cuda_graph_batch_sizes[-1] if
                                            self._cuda_graph_batch_sizes else 0)
@@ -554,7 +568,7 @@ def _run_torch_compile_warmup(self, resource_manager: ResourceManager):

     def _run_autotuner_warmup(self, resource_manager: ResourceManager):
         """Runs a forward pass to populate the autotuner cache."""
-        if not self.pytorch_backend_config.enable_autotuner:
+        if not self.llm_args.enable_autotuner:
             return

         logger.info("Running autotuner warmup...")
@@ -2299,7 +2313,7 @@ def forward(

         with MoeLoadBalancerIterContext(moe_load_balancer):
             # Special handling for multimodal encoder only mode
-            if self.pytorch_backend_config.mm_encoder_only:
+            if self.llm_args.mm_encoder_only:
                 return self._forward_step_mm_encoder_only(
                     inputs, scheduled_requests)
             else:
@@ -2463,7 +2477,7 @@ def _init_userbuffers(self, hidden_size):
         # Disable UB for unsupported platforms
         if not ub.ub_supported():
             return False
-        use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+        use_nccl_symmetric = self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
         ub.initialize_userbuffers_manager(
             self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
             self.mapping.rank, self.mapping.gpus_per_node,
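
When an optional sub-config is absent, the engine now falls back to the Pydantic-declared default of the corresponding field instead of reading a mirrored attribute on PyTorchConfig. A minimal sketch of that fallback pattern (llm_args is assumed to be an already constructed TorchLlmArgs; model_fields[...].default is the standard Pydantic v2 accessor used in the hunk above):

    from tensorrt_llm.llmapi.llm_args import CudaGraphConfig

    # Sketch: use the sub-config value when present, otherwise the field's declared default.
    cfg = llm_args.cuda_graph_config  # may be None when CUDA graphs are disabled
    cuda_graph_batch_sizes = (cfg.batch_sizes if cfg is not None
                              else CudaGraphConfig.model_fields['batch_sizes'].default)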

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 21 additions & 21 deletions
@@ -7,6 +7,7 @@
 import torch

 from tensorrt_llm._utils import str_dtype_to_torch
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -19,7 +20,7 @@
 from ..models.modeling_utils import MetaInitMode, timing
 from ..modules.fused_moe.moe_load_balancer import (
     MoeLoadBalancer, maybe_create_moe_load_balancer)
-from .config import LoadFormat, PyTorchConfig
+from .config import LoadFormat

 _KV_CACHE_MAP = {
     "fp8": QuantAlgo.FP8.value,
@@ -157,7 +158,7 @@ class ModelLoader:
     """

     def __init__(self,
-                 pytorch_backend_config: PyTorchConfig,
+                 llm_args: TorchLlmArgs,
                  mapping: Mapping,
                  spec_config: Optional["DecodingBaseConfig"],
                  sparse_attention_config: Optional["SparseAttentionConfig"],
@@ -168,14 +169,14 @@ def __init__(self,
         Initializes the ModelLoader.

         Args:
-            pytorch_backend_config: Configuration for the PyTorch backend.
+            llm_args: Configuration for the PyTorch backend.
             mapping: The distributed mapping configuration.
             spec_config: Configuration for speculative decoding.
             max_num_tokens: The maximum number of tokens the engine will handle.
             max_seq_len: The maximum sequence length.
             lora_config: Configuration for LoRA.
         """
-        self.pytorch_backend_config = pytorch_backend_config
+        self.llm_args = llm_args
         self.mapping = mapping
         self.spec_config = spec_config
         self.sparse_attention_config = sparse_attention_config
@@ -200,7 +201,7 @@ def load(
         """
         config = self._load_and_validate_config(checkpoint_dir,
                                                 checkpoint_loader)
-        load_format = self.pytorch_backend_config.load_format
+        load_format = self.llm_args.load_format

         with timing("Model init total"), maybe_create_moe_load_balancer(
                 config, self.mapping) as moe_load_balancer:
@@ -291,30 +292,29 @@ def _load_and_validate_config(
             checkpoint_dir,
             trust_remote_code=True,
             mapping=self.mapping,
-            enable_min_latency=self.pytorch_backend_config.enable_min_latency,
-            use_cuda_graph=self.pytorch_backend_config.use_cuda_graph,
-            force_dynamic_quantization=self.pytorch_backend_config.
-            force_dynamic_quantization,
+            enable_min_latency=self.llm_args.enable_min_latency,
+            use_cuda_graph=self.llm_args.cuda_graph_config is not None,
+            force_dynamic_quantization=self.llm_args.force_dynamic_quantization,
             spec_config=self.spec_config,
             sparse_attention_config=self.sparse_attention_config,
             max_num_tokens=self.max_num_tokens,
             max_seq_len=self.max_seq_len,
-            moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens,
-            moe_load_balancer=self.pytorch_backend_config.moe_load_balancer,
+            moe_max_num_tokens=self.llm_args.moe_config.max_num_tokens,
+            moe_load_balancer=self.llm_args.moe_config.load_balancer,
             lora_config=self.lora_config,
-            allreduce_strategy=self.pytorch_backend_config.allreduce_strategy,
-            mm_encoder_only=self.pytorch_backend_config.mm_encoder_only,
-            attn_backend=self.pytorch_backend_config.attn_backend,
-            moe_backend=self.pytorch_backend_config.moe_backend,
-            moe_disable_finalize_fusion=self.pytorch_backend_config.
-            moe_disable_finalize_fusion,
-            use_low_precision_moe_combine=self.pytorch_backend_config.
+            allreduce_strategy=self.llm_args.allreduce_strategy,
+            mm_encoder_only=self.llm_args.mm_encoder_only,
+            attn_backend=self.llm_args.attn_backend,
+            moe_backend=self.llm_args.moe_config.backend,
+            moe_disable_finalize_fusion=self.llm_args.moe_config.
+            disable_finalize_fusion,
+            use_low_precision_moe_combine=self.llm_args.moe_config.
             use_low_precision_moe_combine)

-        validate_and_set_kv_cache_quant(
-            config, self.pytorch_backend_config.kv_cache_dtype)
+        validate_and_set_kv_cache_quant(config,
+                                        self.llm_args.kv_cache_config.dtype)
         validate_and_set_mamba_ssm_cache_dtype(
-            config, self.pytorch_backend_config.mamba_ssm_cache_dtype)
+            config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)

         # Allow overriding the number of layers via environment variable
         num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM",
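
The loader changes are a mechanical relocation: flat pytorch_backend_config fields move to nested sub-configs on llm_args. A short sketch of the mapping as it appears in the hunk above (the left-hand names are illustrative locals; only the right-hand attribute paths come from the diff):

    kv_cache_dtype  = llm_args.kv_cache_config.dtype                  # was pytorch_backend_config.kv_cache_dtype
    mamba_ssm_dtype = llm_args.kv_cache_config.mamba_ssm_cache_dtype  # was pytorch_backend_config.mamba_ssm_cache_dtype
    moe_backend     = llm_args.moe_config.backend                     # was pytorch_backend_config.moe_backend
    use_cuda_graph  = llm_args.cuda_graph_config is not None          # was pytorch_backend_config.use_cuda_graph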

0 commit comments
