Commit ad662dd

chore: disallow arbitrary in llm_args.Configs (NVIDIA#6367)
Signed-off-by: Superjomn <[email protected]>
1 parent 1a69309 commit ad662dd

File tree

2 files changed (+289 −53 lines)


tensorrt_llm/llmapi/llm_args.py

Lines changed: 63 additions & 52 deletions
@@ -92,7 +92,16 @@ def Field(default: Any = ...,
     return PydanticField(default, **kwargs)
 
 
-class CudaGraphConfig(BaseModel):
+class StrictBaseModel(BaseModel):
+    """
+    A base model that forbids arbitrary fields.
+    """
+
+    class Config:
+        extra = "forbid"  # globally forbid arbitrary fields
+
+
+class CudaGraphConfig(StrictBaseModel):
     """
     Configuration for CUDA graphs.
     """
@@ -119,8 +128,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
                 "cuda_graph_config.max_batch_size must be non-negative")
         return v
 
+    @staticmethod
+    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
+                                         enable_padding: bool) -> List[int]:
+        """Generate a list of batch sizes for CUDA graphs.
+
+        Args:
+            max_batch_size: Maximum batch size to generate up to
+            enable_padding: Whether padding is enabled, which affects the batch size distribution
+
+        Returns:
+            List of batch sizes to create CUDA graphs for
+        """
+        if enable_padding:
+            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        else:
+            batch_sizes = list(range(1, 32)) + [32, 64, 128]
+
+        # Add powers of 2 up to max_batch_size
+        batch_sizes += [
+            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
+        ]
+
+        # Filter and sort batch sizes
+        batch_sizes = sorted(
+            [size for size in batch_sizes if size <= max_batch_size])
+
+        # Add max_batch_size if not already included
+        if max_batch_size != batch_sizes[-1]:
+            batch_sizes.append(max_batch_size)
+
+        return batch_sizes
+
 
-class MoeConfig(BaseModel):
+class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
     """
@@ -225,7 +266,7 @@ def to_mapping(self) -> Mapping:
             auto_parallel=self.auto_parallel)
 
 
-class CalibConfig(BaseModel):
+class CalibConfig(StrictBaseModel):
     """
     Calibration configuration.
     """
@@ -277,7 +318,7 @@ class _ModelFormatKind(Enum):
     TLLM_ENGINE = 2
 
 
-class DecodingBaseConfig(BaseModel):
+class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
@@ -298,6 +339,7 @@ def from_dict(cls, data: dict):
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
+        data.pop("decoding_type")
 
         return config_class(**data)
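
(Note, not part of the diff: the new data.pop("decoding_type") line is a direct consequence of forbidding extras; forwarding the discriminator key to the selected config class would itself raise. Sketch below with a simplified stand-in class:)

from pydantic import BaseModel, ValidationError

class StrictBaseModel(BaseModel):

    class Config:
        extra = "forbid"

class DraftConfigSketch(StrictBaseModel):  # stand-in for a decoding config
    max_draft_len: int = 0

data = {"decoding_type": "Draft", "max_draft_len": 4}
try:
    DraftConfigSketch(**data)  # raises: "decoding_type" is not declared
except ValidationError:
    pass
data.pop("decoding_type")           # what the new line in from_dict does
config = DraftConfigSketch(**data)  # now succeeds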

@@ -496,7 +538,7 @@ def mirror_pybind_fields(pybind_class):
     """
 
     def decorator(cls):
-        assert issubclass(cls, BaseModel)
+        assert issubclass(cls, StrictBaseModel)
         # Get all non-private fields from the C++ class
         cpp_fields = PybindMirror.get_pybind_variable_fields(pybind_class)
         python_fields = set(cls.model_fields.keys())
@@ -597,7 +639,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
-class DynamicBatchConfig(BaseModel, PybindMirror):
+class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.
 
     Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -623,7 +665,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_SchedulerConfig)
-class SchedulerConfig(BaseModel, PybindMirror):
+class SchedulerConfig(StrictBaseModel, PybindMirror):
     capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
         default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         description="The capacity scheduler policy to use")
@@ -645,7 +687,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
-class PeftCacheConfig(BaseModel, PybindMirror):
+class PeftCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the PEFT cache.
     """
@@ -773,7 +815,7 @@ def supports_backend(self, backend: str) -> bool:
 
 
 @PybindMirror.mirror_pybind_fields(_KvCacheConfig)
-class KvCacheConfig(BaseModel, PybindMirror):
+class KvCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the KV cache.
     """
@@ -856,7 +898,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
-class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
+class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for extended runtime performance knobs.
     """
@@ -887,7 +929,7 @@ def _to_pybind(self):
 
 
 @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
-class CacheTransceiverConfig(BaseModel, PybindMirror):
+class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the cache transceiver.
     """
@@ -947,7 +989,7 @@ def model_name(self) -> Union[str, Path]:
         return self.model if isinstance(self.model, str) else None
 
 
-class BaseLlmArgs(BaseModel):
+class BaseLlmArgs(StrictBaseModel):
     """
     Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
     """
@@ -1354,7 +1396,8 @@ def init_build_config(self):
         """
         Creating a default BuildConfig if none is provided
         """
-        if self.build_config is None:
+        build_config = getattr(self, "build_config", None)
+        if build_config is None:
             kwargs = {}
             if self.max_batch_size:
                 kwargs["max_batch_size"] = self.max_batch_size
@@ -1367,10 +1410,10 @@ def init_build_config(self):
             if self.max_input_len:
                 kwargs["max_input_len"] = self.max_input_len
             self.build_config = BuildConfig(**kwargs)
-
-        assert isinstance(
-            self.build_config, BuildConfig
-        ), f"build_config is not initialized: {self.build_config}"
+        else:
+            assert isinstance(
+                build_config,
+                BuildConfig), f"build_config is not initialized: {build_config}"
         return self
 
     @model_validator(mode="after")
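
(Note, not part of the diff: the switch from self.build_config to getattr is presumably needed because, with extras forbidden, a subclass that does not declare build_config can no longer carry one as a stray attribute, and plain attribute access would raise. A minimal sketch with a hypothetical model:)

from pydantic import BaseModel

class StrictBaseModel(BaseModel):

    class Config:
        extra = "forbid"

class ArgsSketch(StrictBaseModel):  # hypothetical: no build_config field
    max_batch_size: int = 8

args = ArgsSketch()
print(getattr(args, "build_config", None))  # None instead of AttributeError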
@@ -1813,7 +1856,7 @@ class LoadFormat(Enum):
     DUMMY = 1
 
 
-class TorchCompileConfig(BaseModel):
+class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
     """
@@ -2049,38 +2092,6 @@ def validate_checkpoint_format(self):
 
         return self
 
-    @staticmethod
-    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-                                         enable_padding: bool) -> List[int]:
-        """Generate a list of batch sizes for CUDA graphs.
-
-        Args:
-            max_batch_size: Maximum batch size to generate up to
-            enable_padding: Whether padding is enabled, which affects the batch size distribution
-
-        Returns:
-            List of batch sizes to create CUDA graphs for
-        """
-        if enable_padding:
-            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
-        else:
-            batch_sizes = list(range(1, 32)) + [32, 64, 128]
-
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
-        batch_sizes = sorted(
-            [size for size in batch_sizes if size <= max_batch_size])
-
-        # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
-            batch_sizes.append(max_batch_size)
-
-        return batch_sizes
-
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
         from .._torch import MoeLoadBalancerConfig
@@ -2117,7 +2128,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             if config.max_batch_size != 0:
-                if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
+                if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
                         config.max_batch_size, config.enable_padding):
                     raise ValueError(
                         "Please don't set both cuda_graph_config.batch_sizes "
@@ -2129,7 +2140,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
                 config.max_batch_size = max(config.batch_sizes)
             else:
                 max_batch_size = config.max_batch_size or 128
-                generated_sizes = self._generate_cuda_graph_batch_sizes(
+                generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                     max_batch_size, config.enable_padding)
                 config.batch_sizes = generated_sizes
                 config.max_batch_size = max_batch_size
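
(Note, not part of the diff: these two call-site edits follow from moving the helper onto CudaGraphConfig, so the TorchLlmArgs validator resolves it through the class rather than self. A simplified sketch of the check it performs, reusing generate_cuda_graph_batch_sizes from the earlier snippet; the error text is truncated in the hunk above and elided here:)

batch_sizes = sorted([1, 2, 4])
max_batch_size = 4
expected = generate_cuda_graph_batch_sizes(max_batch_size, enable_padding=False)
if batch_sizes != expected:  # [1, 2, 4] != [1, 2, 3, 4]
    raise ValueError("Please don't set both cuda_graph_config.batch_sizes ...")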
