@@ -20,7 +20,8 @@
                                             MultimodalRuntimeData)
 from tensorrt_llm.inputs.registry import (create_input_processor,
                                           create_input_processor_with_hash)
-from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
+from tensorrt_llm.llmapi.llm_args import (CudaGraphConfig, TorchCompileConfig,
+                                          TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
@@ -53,7 +54,7 @@
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
                      set_torch_compiling, with_model_extra_attrs)
-from .config import PyTorchConfig, _construct_checkpoint_loader
+from .config import _construct_checkpoint_loader
 from .config_utils import is_mla
 from .cuda_graph_runner import CUDAGraphRunner
 from .guided_decoder import CapturableGuidedDecoder
@@ -131,7 +132,7 @@ def __init__(
         self,
         *,
         model_path: str,
-        pytorch_backend_config: PyTorchConfig,
+        llm_args: TorchLlmArgs,
         mapping: Optional[Mapping] = None,
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
@@ -140,10 +141,7 @@ def __init__(
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
         model: Optional[torch.nn.Module] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
     ):
-        assert llm_args is not None, "llm_args must be provided for PyTorchModelEngine"
-
         self.forward_pass_callable = None
         self.ub_buffers = None
         (
@@ -168,7 +166,7 @@ def __init__(
         self.dist = dist
         if dist is not None:
             ExpertStatistic.create(self.dist.rank)
-        self.pytorch_backend_config = pytorch_backend_config
+        self.llm_args = llm_args
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
         self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
 
@@ -192,7 +190,7 @@ def __init__(
         lora_config: Optional[
             LoraConfig] = None if is_draft_model else llm_args.lora_config
         loader = ModelLoader(
-            pytorch_backend_config=pytorch_backend_config,
+            llm_args=llm_args,
             mapping=self.mapping,
             spec_config=self.spec_config,
             sparse_attention_config=self.sparse_attention_config,
@@ -215,7 +213,7 @@ def __init__(
         # In case that some tests use stub models and override `_load_model`.
         if not hasattr(self.model, 'extra_attrs'):
             self.model.extra_attrs = {}
-        if self.pytorch_backend_config.enable_layerwise_nvtx_marker:
+        if self.llm_args.enable_layerwise_nvtx_marker:
             layerwise_nvtx_marker = LayerwiseNvtxMarker()
             module_prefix = 'Model'
             if self.model.model_config and self.model.model_config.pretrained_config and self.model.model_config.pretrained_config.architectures:
@@ -224,19 +222,39 @@ def __init__(
             layerwise_nvtx_marker.register_hooks(self.model, module_prefix)
 
         self.enable_attention_dp = self.model.model_config.mapping.enable_attention_dp
-        self._disable_overlap_scheduler = self.pytorch_backend_config.disable_overlap_scheduler
+        self._disable_overlap_scheduler = self.llm_args.disable_overlap_scheduler
         self._torch_compile_backend = None
         self.dtype = self.model.config.torch_dtype
         self._init_model_capacity()
 
-        self._torch_compile_backend = None
+        self.cuda_graph_config = self.llm_args.cuda_graph_config
+        cuda_graph_batch_sizes = self.cuda_graph_config.batch_sizes if self.cuda_graph_config else CudaGraphConfig.model_fields[
+            'batch_sizes'].default
+        cuda_graph_padding_enabled = self.cuda_graph_config.enable_padding if self.cuda_graph_config else CudaGraphConfig.model_fields[
+            'enable_padding'].default
+
+        self.torch_compile_config = self.llm_args.torch_compile_config
+        torch_compile_enabled = bool(self.torch_compile_config is not None)
+        torch_compile_fullgraph = self.torch_compile_config.enable_fullgraph if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_fullgraph'].default
+        torch_compile_inductor_enabled = self.torch_compile_config.enable_inductor if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_inductor'].default
+        torch_compile_piecewise_cuda_graph = self.torch_compile_config.enable_piecewise_cuda_graph if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_piecewise_cuda_graph'].default
+        torch_compile_piecewise_cuda_graph_num_tokens = self.torch_compile_config.capture_num_tokens if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'capture_num_tokens'].default
+        torch_compile_enable_userbuffers = self.torch_compile_config.enable_userbuffers if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'enable_userbuffers'].default
+        torch_compile_max_num_streams = self.torch_compile_config.max_num_streams if self.torch_compile_config is not None else TorchCompileConfig.model_fields[
+            'max_num_streams'].default
+
         # Eagle3 draft model now does not support torch.compile
-        self._torch_compile_enabled = pytorch_backend_config.torch_compile_enabled and not is_draft_model
-        self._torch_compile_piecewise_cuda_graph = pytorch_backend_config.torch_compile_piecewise_cuda_graph
+        self._torch_compile_enabled = torch_compile_enabled
+        self._torch_compile_piecewise_cuda_graph = torch_compile_piecewise_cuda_graph
 
         piecewise_cuda_graph_num_tokens = (
-            pytorch_backend_config.torch_compile_piecewise_cuda_graph_num_tokens
-            or pytorch_backend_config.cuda_graph_batch_sizes or [])
+            torch_compile_piecewise_cuda_graph_num_tokens
+            or cuda_graph_batch_sizes or [])
 
         self._piecewise_cuda_graph_num_tokens = [
             i for i in piecewise_cuda_graph_num_tokens
@@ -245,33 +263,30 @@ def __init__(
 
         try:
             use_ub_for_nccl = (
-                pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+                self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
                 and self._init_userbuffers(self.model.config.hidden_size))
             if self._torch_compile_enabled:
                 set_torch_compiling(True)
                 use_ub = not use_ub_for_nccl and (
-                    pytorch_backend_config.torch_compile_enable_userbuffers
+                    torch_compile_enable_userbuffers
                     and self._init_userbuffers(self.model.config.hidden_size))
                 self._torch_compile_backend = Backend(
-                    pytorch_backend_config.torch_compile_inductor_enabled,
+                    torch_compile_inductor_enabled,
                     enable_userbuffers=use_ub,
                     enable_piecewise_cuda_graph=self.
                     _torch_compile_piecewise_cuda_graph,
                     capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
-                    max_num_streams=pytorch_backend_config.
-                    torch_compile_max_num_streams)
+                    max_num_streams=torch_compile_max_num_streams)
                 if isinstance(self.model, DecoderModelForCausalLM):
                     self.model.model = torch.compile(
                         self.model.model,
                         backend=self._torch_compile_backend,
-                        fullgraph=pytorch_backend_config.torch_compile_fullgraph
-                    )
+                        fullgraph=torch_compile_fullgraph)
                 else:
                     self.model = torch.compile(
                         self.model,
                         backend=self._torch_compile_backend,
-                        fullgraph=pytorch_backend_config.torch_compile_fullgraph
-                    )
+                        fullgraph=torch_compile_fullgraph)
                 torch._dynamo.config.cache_size_limit = 16
             else:
                 set_torch_compiling(False)
@@ -283,7 +298,7 @@ def __init__(
         self.is_warmup = False
 
         self.attn_backend = get_attention_backend(
-            pytorch_backend_config.attn_backend,
+            self.llm_args.attn_backend,
             sparse_attn_config=self.sparse_attention_config)
 
         if self.is_spec_decode:
@@ -329,13 +344,12 @@ def __init__(
         self.iter_states = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
 
-        self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
+        self._cuda_graph_padding_enabled = cuda_graph_padding_enabled
 
         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
-            pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_total_draft_tokens,
-            self._cuda_graph_padding_enabled
-        ) if pytorch_backend_config.cuda_graph_batch_sizes else []
+            cuda_graph_batch_sizes, self.batch_size, self.max_num_tokens,
+            self.original_max_total_draft_tokens,
+            self._cuda_graph_padding_enabled) if cuda_graph_batch_sizes else []
 
         self._max_cuda_graph_batch_size = (self._cuda_graph_batch_sizes[-1] if
                                            self._cuda_graph_batch_sizes else 0)
@@ -554,7 +568,7 @@ def _run_torch_compile_warmup(self, resource_manager: ResourceManager):
 
     def _run_autotuner_warmup(self, resource_manager: ResourceManager):
         """Runs a forward pass to populate the autotuner cache."""
-        if not self.pytorch_backend_config.enable_autotuner:
+        if not self.llm_args.enable_autotuner:
             return
 
         logger.info("Running autotuner warmup...")
@@ -2299,7 +2313,7 @@ def forward(
 
         with MoeLoadBalancerIterContext(moe_load_balancer):
             # Special handling for multimodal encoder only mode
-            if self.pytorch_backend_config.mm_encoder_only:
+            if self.llm_args.mm_encoder_only:
                 return self._forward_step_mm_encoder_only(
                     inputs, scheduled_requests)
             else:
@@ -2463,7 +2477,7 @@ def _init_userbuffers(self, hidden_size):
         # Disable UB for unsupported platforms
         if not ub.ub_supported():
             return False
-        use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+        use_nccl_symmetric = self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
         ub.initialize_userbuffers_manager(
             self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
             self.mapping.rank, self.mapping.gpus_per_node,
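
Note on the `model_fields[...].default` pattern used in the added block above: when the user passes no `cuda_graph_config` or `torch_compile_config` on `llm_args`, the engine falls back to the default declared on the Pydantic config class itself, so behavior matches an explicitly constructed config with no overrides. Below is a minimal, self-contained sketch of that fallback pattern; `DemoGraphConfig` and `resolve_batch_sizes` are simplified stand-ins for illustration, not the real `CudaGraphConfig` API.

```python
from typing import List, Optional

from pydantic import BaseModel


class DemoGraphConfig(BaseModel):
    """Simplified stand-in for a config object like CudaGraphConfig."""
    batch_sizes: Optional[List[int]] = None
    enable_padding: bool = False


def resolve_batch_sizes(cfg: Optional[DemoGraphConfig]) -> Optional[List[int]]:
    # Same idea as in the diff: if no config object was supplied, read the
    # field default from the Pydantic model instead of hard-coding it here.
    return (cfg.batch_sizes if cfg is not None else
            DemoGraphConfig.model_fields['batch_sizes'].default)


print(resolve_batch_sizes(None))                                     # None
print(resolve_batch_sizes(DemoGraphConfig(batch_sizes=[1, 2, 4])))   # [1, 2, 4]
```

The advantage of this pattern is that a default changed on the config class propagates automatically to every call site that uses the `model_fields` lookup, instead of drifting out of sync with duplicated literals.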