2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -121,7 +121,7 @@ def __init__(self, *args, **kwargs):
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
self.model_kwargs = deep_merge_dicts(
self._model_defaults,
self.model_kwargs,
self.model_kwargs or {},
)

# set sharding config source to huggingface
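The `or {}` guard matters because `model_kwargs` may be `None` when the user never sets it. A minimal sketch of the failure mode it avoids, using a hypothetical `_merge` helper rather than the repo's `deep_merge_dicts`:

```python
# Hypothetical recursive merge, similar in spirit to deep_merge_dicts.
def _merge(base: dict, override: dict) -> dict:
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = _merge(out[key], value)  # merge nested dicts
        else:
            out[key] = value  # override wins for plain values
    return out

defaults = {"torch_dtype": "auto", "use_cache": False}
model_kwargs = None  # the user never set model_kwargs

merged = _merge(defaults, model_kwargs or {})  # defaults pass through untouched
# _merge(defaults, model_kwargs) would fail instead: None has no .items()
```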
52 changes: 51 additions & 1 deletion tensorrt_llm/_torch/model_config.py
@@ -4,7 +4,7 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, TypeVar

import filelock
import torch
@@ -452,6 +452,56 @@ def cached_file(path_or_repo_id, file_name):
# Some checkpoints lack torch_dtype, populate with dtype
pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
None)

# Apply model_kwargs to override config parameters if provided
model_kwargs = kwargs.pop('model_kwargs', None)
if model_kwargs:

def _recursive_update_config(config: transformers.PretrainedConfig,
update_dict: Dict[str, Any]):
"""
Recursively update a PretrainedConfig object with values from update_dict.

Args:
config: PretrainedConfig object to update
update_dict: Dictionary with values to update in the config
"""
for key, value_new in update_dict.items():
if not hasattr(config, key):
logger.warning(
f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
)
continue

target_value = getattr(config, key)

# Handle nested PretrainedConfig objects when value is a dict
if isinstance(value_new, dict) and isinstance(
target_value, transformers.PretrainedConfig):
# Recursively update the nested config
logger.info(
f"Recursively updating nested config: {key}")
_recursive_update_config(target_value, value_new)
elif (key in ["torch_dtype", "dtype"]
and isinstance(value_new, str)
and value_new != "auto"):
# Special handling of torch_dtype (DEPRECATED!) and dtype keys to ensure we
# set a proper torch.dtype object instead of a string.
dtype = getattr(torch, value_new)
assert isinstance(dtype,
torch.dtype), f"Invalid {dtype=}"
setattr(config, key, dtype)
logger.info(
f"Applied model_kwargs: {key}={dtype} (previous value: {target_value})"
)
else:
# Direct update for simple values
setattr(config, key, value_new)
logger.info(
f"Applied model_kwargs: {key}={value_new} (previous value: {target_value})"
)

_recursive_update_config(pretrained_config, model_kwargs)

quant_config = QuantConfig()
layer_quant_config = None
moe_backend = kwargs.get('moe_backend', 'CUTLASS')
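A hedged usage sketch of the override semantics added above. The checkpoint path is a placeholder and whether a given key (e.g. `text_config`) exists depends on the model, but the plain-value, nested-dict, dtype-string, and unknown-key branches mirror the new `_recursive_update_config`:

```python
from tensorrt_llm._torch.model_config import ModelConfig

model_kwargs = {
    "num_hidden_layers": 2,     # plain value: set directly on the config
    "torch_dtype": "bfloat16",  # dtype string: converted to torch.bfloat16
    "text_config": {            # dict over a nested PretrainedConfig: recursive update
        "num_hidden_layers": 2,
    },
    "not_a_real_key": 123,      # unknown key: warning logged, value ignored
}

config = ModelConfig.from_pretrained(
    "/path/to/checkpoint",      # placeholder checkpoint path
    model_kwargs=model_kwargs,
)
assert config.pretrained_config.num_hidden_layers == 2
```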
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
@@ -8,6 +8,7 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.modeling_utils import register_config_loader
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

@@ -327,6 +328,15 @@ def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
block_size = (128, 128)
quant_config.group_size = block_size[0]

# model_kwargs is not supported for Mistral format checkpoints
# Extract it from kwargs to avoid passing to ModelConfig.__init__ (which doesn't accept it)
model_kwargs = kwargs.pop("model_kwargs", None)
if model_kwargs:
logger.warning(
"model_kwargs is not supported for Mistral format checkpoints. "
f"Ignoring model_kwargs: {model_kwargs}"
)

kwargs.pop("trust_remote_code", None) # ModelConfig does not have this input parameter
model_config = ModelConfig(
pretrained_config=pretrained_config,
14 changes: 12 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -338,8 +338,8 @@ def _load_and_validate_config(
self, checkpoint_dir: str,
checkpoint_loader: BaseCheckpointLoader) -> ModelConfig:
"""Loads and validates the model configuration."""
config = checkpoint_loader.load_config(
checkpoint_dir,
load_config_kwargs = dict(
checkpoint_dir=checkpoint_dir,
trust_remote_code=True,
mapping=self.mapping,
enable_min_latency=self.llm_args.enable_min_latency,
@@ -363,6 +363,12 @@ def _load_and_validate_config(
nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
allowed_backends)

# Only pass model_kwargs if it's explicitly set (not None)
if self.llm_args.model_kwargs is not None:
load_config_kwargs['model_kwargs'] = self.llm_args.model_kwargs

config = checkpoint_loader.load_config(**load_config_kwargs)

# Store nvfp4 config in extra_attrs for Linear layer access
config.extra_attrs[
'nvfp4_gemm_allowed_backends'] = config.nvfp4_gemm_allowed_backends
@@ -373,9 +379,13 @@
config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)

# Allow overriding the number of layers via environment variable
# Note: This is kept for backward compatibility, but model_kwargs is preferred
num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM",
"0"))
if num_layers_override > 0:
logger.warning(
f"TLLM_OVERRIDE_LAYER_NUM is deprecated. Use model_kwargs instead: "
f"model_kwargs={{'num_hidden_layers': {num_layers_override}}}")
config.pretrained_config.num_hidden_layers = num_layers_override
for sub_config in ["text_config", "vision_config"]:
if hasattr(config.pretrained_config, sub_config):
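As a usage note for the deprecation warning above, a sketch of the replacement path: pass the layer-count override through `model_kwargs` on the LLM API instead of exporting `TLLM_OVERRIDE_LAYER_NUM`. The model path is a placeholder:

```python
from tensorrt_llm import LLM

# Replaces exporting TLLM_OVERRIDE_LAYER_NUM=2 before launch; the override
# now travels through llm_args.model_kwargs into load_config.
llm = LLM(
    model="/path/to/model",                 # placeholder checkpoint path
    model_kwargs={"num_hidden_layers": 2},  # overrides the checkpoint config
)
```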
7 changes: 7 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -1907,6 +1907,13 @@ class BaseLlmArgs(StrictBaseModel):

# Below are all remaining arguments

model_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description="Optional parameters overriding model config defaults. "
"Precedence: (1) model_kwargs, (2) model config file, (3) model config class defaults. "
"Unknown keys are ignored",
status="prototype")

pipeline_parallel_size: int = Field(
default=1, description="The pipeline parallel size.")

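A sketch of the precedence rule in the field description: a key set via `model_kwargs` (1) wins over the checkpoint's config file (2), which in turn wins over the config class defaults (3). The import path and model path are assumptions:

```python
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs  # assumed import path

args = TorchLlmArgs(
    model="/path/to/llama",                 # placeholder checkpoint path
    model_kwargs={"num_hidden_layers": 2},  # takes precedence over config.json
)
assert args.model_kwargs["num_hidden_layers"] == 2
```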
4 changes: 4 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -227,6 +227,10 @@ methods:
annotation: Optional[Dict[str, str]]
default: null
status: prototype
model_kwargs:
annotation: Optional[Dict[str, Any]]
default: null
status: prototype
return_annotation: None
generate:
parameters:
25 changes: 25 additions & 0 deletions tests/unittest/llmapi/test_llm_args.py
@@ -138,6 +138,19 @@ def test_llm_args_with_pydantic_options(self):
assert llm_args.max_num_tokens == 256
assert llm_args.max_seq_len == 128

@pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
def test_llm_args_with_model_kwargs(self, llm_args_cls):
yaml_content = """
model_kwargs:
num_hidden_layers: 2
"""
dict_content = self._yaml_to_dict(yaml_content)
llm_args = llm_args_cls(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = llm_args_cls(**llm_args_dict)
assert llm_args.model_kwargs['num_hidden_layers'] == 2


def check_defaults(py_config_cls, pybind_config_cls):
py_config = py_config_cls()
@@ -445,6 +458,18 @@ def test_dynamic_setattr(self):
args = TorchLlmArgs(model=llama_model_path)
args.invalid_arg = 1

@print_traceback_on_error
def test_model_kwargs_with_num_hidden_layers(self):
"""Test that model_kwargs can override num_hidden_layers."""
from tensorrt_llm._torch.model_config import ModelConfig
config_no_kwargs = ModelConfig.from_pretrained(
llama_model_path).pretrained_config
model_kwargs = {'num_hidden_layers': 2}
config_with_kwargs = ModelConfig.from_pretrained(
llama_model_path, model_kwargs=model_kwargs).pretrained_config
assert config_no_kwargs.num_hidden_layers != config_with_kwargs.num_hidden_layers
assert config_with_kwargs.num_hidden_layers == 2


class TestTrtLlmArgs:
