@@ -92,7 +92,16 @@ def Field(default: Any = ...,
     return PydanticField(default, **kwargs)


-class CudaGraphConfig(BaseModel):
+class StrictBaseModel(BaseModel):
+    """
+    A base model that forbids extra (undeclared) fields.
+    """
+
+    class Config:
+        extra = "forbid"  # forbid undeclared fields on all subclasses
+
+
+class CudaGraphConfig(StrictBaseModel):
     """
     Configuration for CUDA graphs.
     """
@@ -119,8 +128,40 @@ def validate_cuda_graph_max_batch_size(cls, v):
119128 "cuda_graph_config.max_batch_size must be non-negative" )
120129 return v
121130
131+ @staticmethod
132+ def _generate_cuda_graph_batch_sizes (max_batch_size : int ,
133+ enable_padding : bool ) -> List [int ]:
134+ """Generate a list of batch sizes for CUDA graphs.
135+
136+ Args:
137+ max_batch_size: Maximum batch size to generate up to
138+ enable_padding: Whether padding is enabled, which affects the batch size distribution
139+
140+ Returns:
141+ List of batch sizes to create CUDA graphs for
142+ """
143+ if enable_padding :
144+ batch_sizes = [1 , 2 , 4 ] + [i * 8 for i in range (1 , 17 )]
145+ else :
146+ batch_sizes = list (range (1 , 32 )) + [32 , 64 , 128 ]
147+
148+ # Add powers of 2 up to max_batch_size
149+ batch_sizes += [
150+ 2 ** i for i in range (8 , math .floor (math .log (max_batch_size , 2 )))
151+ ]
152+
153+ # Filter and sort batch sizes
154+ batch_sizes = sorted (
155+ [size for size in batch_sizes if size <= max_batch_size ])
156+
157+ # Add max_batch_size if not already included
158+ if max_batch_size != batch_sizes [- 1 ]:
159+ batch_sizes .append (max_batch_size )
160+
161+ return batch_sizes
162+
122163
123- class MoeConfig (BaseModel ):
164+ class MoeConfig (StrictBaseModel ):
124165 """
125166 Configuration for MoE.
126167 """
@@ -225,7 +266,7 @@ def to_mapping(self) -> Mapping:
                        auto_parallel=self.auto_parallel)


-class CalibConfig(BaseModel):
+class CalibConfig(StrictBaseModel):
     """
     Calibration configuration.
     """
@@ -277,7 +318,7 @@ class _ModelFormatKind(Enum):
     TLLM_ENGINE = 2


-class DecodingBaseConfig(BaseModel):
+class DecodingBaseConfig(StrictBaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None

@@ -298,6 +339,7 @@ def from_dict(cls, data: dict):
         config_class = config_classes.get(decoding_type)
         if config_class is None:
             raise ValueError(f"Invalid decoding type: {decoding_type}")
+        data.pop("decoding_type")

         return config_class(**data)

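
The new `pop` matters because the concrete config classes now forbid extra fields: leaving `decoding_type` in `data` would make `config_class(**data)` raise a validation error, since `decoding_type` is a dispatch key rather than a model field. A hedged sketch (the decoding type name and field are illustrative):

    payload = {"decoding_type": "MTP", "max_draft_len": 4}
    cfg = DecodingBaseConfig.from_dict(payload)  # dispatches, then strips the key
    # Without the pop, the strict constructor would reject "decoding_type"
    # as an extra input.
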
@@ -496,7 +538,7 @@ def mirror_pybind_fields(pybind_class):
496538 """
497539
498540 def decorator (cls ):
499- assert issubclass (cls , BaseModel )
541+ assert issubclass (cls , StrictBaseModel )
500542 # Get all non-private fields from the C++ class
501543 cpp_fields = PybindMirror .get_pybind_variable_fields (pybind_class )
502544 python_fields = set (cls .model_fields .keys ())
@@ -597,7 +639,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
-class DynamicBatchConfig(BaseModel, PybindMirror):
+class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.

     Controls how batch size and token limits are dynamically adjusted at runtime.
@@ -623,7 +665,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_SchedulerConfig)
-class SchedulerConfig(BaseModel, PybindMirror):
+class SchedulerConfig(StrictBaseModel, PybindMirror):
     capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
         default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         description="The capacity scheduler policy to use")
@@ -645,7 +687,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
-class PeftCacheConfig(BaseModel, PybindMirror):
+class PeftCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the PEFT cache.
     """
@@ -773,7 +815,7 @@ def supports_backend(self, backend: str) -> bool:


 @PybindMirror.mirror_pybind_fields(_KvCacheConfig)
-class KvCacheConfig(BaseModel, PybindMirror):
+class KvCacheConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the KV cache.
     """
@@ -856,7 +898,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
-class ExtendedRuntimePerfKnobConfig(BaseModel, PybindMirror):
+class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for extended runtime performance knobs.
     """
@@ -887,7 +929,7 @@ def _to_pybind(self):


 @PybindMirror.mirror_pybind_fields(_CacheTransceiverConfig)
-class CacheTransceiverConfig(BaseModel, PybindMirror):
+class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
     """
     Configuration for the cache transceiver.
     """
@@ -947,7 +989,7 @@ def model_name(self) -> Union[str, Path]:
         return self.model if isinstance(self.model, str) else None


-class BaseLlmArgs(BaseModel):
+class BaseLlmArgs(StrictBaseModel):
     """
     Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.
     """
@@ -1354,7 +1396,8 @@ def init_build_config(self):
13541396 """
13551397 Creating a default BuildConfig if none is provided
13561398 """
1357- if self .build_config is None :
1399+ build_config = getattr (self , "build_config" , None )
1400+ if build_config is None :
13581401 kwargs = {}
13591402 if self .max_batch_size :
13601403 kwargs ["max_batch_size" ] = self .max_batch_size
@@ -1367,10 +1410,10 @@ def init_build_config(self):
             if self.max_input_len:
                 kwargs["max_input_len"] = self.max_input_len
             self.build_config = BuildConfig(**kwargs)
-
-        assert isinstance(
-            self.build_config, BuildConfig
-        ), f"build_config is not initialized: {self.build_config}"
+        else:
+            assert isinstance(
+                build_config,
+                BuildConfig), f"build_config is not initialized: {build_config}"
         return self

     @model_validator(mode="after")
@@ -1813,7 +1856,7 @@ class LoadFormat(Enum):
     DUMMY = 1


-class TorchCompileConfig(BaseModel):
+class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
     """
@@ -2049,38 +2092,6 @@ def validate_checkpoint_format(self):

         return self

-    @staticmethod
-    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-                                         enable_padding: bool) -> List[int]:
-        """Generate a list of batch sizes for CUDA graphs.
-
-        Args:
-            max_batch_size: Maximum batch size to generate up to
-            enable_padding: Whether padding is enabled, which affects the batch size distribution
-
-        Returns:
-            List of batch sizes to create CUDA graphs for
-        """
-        if enable_padding:
-            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
-        else:
-            batch_sizes = list(range(1, 32)) + [32, 64, 128]
-
-        # Add powers of 2 up to max_batch_size
-        batch_sizes += [
-            2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
-        ]
-
-        # Filter and sort batch sizes
-        batch_sizes = sorted(
-            [size for size in batch_sizes if size <= max_batch_size])
-
-        # Add max_batch_size if not already included
-        if max_batch_size != batch_sizes[-1]:
-            batch_sizes.append(max_batch_size)
-
-        return batch_sizes
-
     @model_validator(mode="after")
     def validate_load_balancer(self) -> 'TorchLlmArgs':
         from .._torch import MoeLoadBalancerConfig
@@ -2117,7 +2128,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
         if config.batch_sizes:
             config.batch_sizes = sorted(config.batch_sizes)
             if config.max_batch_size != 0:
-                if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
+                if config.batch_sizes != CudaGraphConfig._generate_cuda_graph_batch_sizes(
                         config.max_batch_size, config.enable_padding):
                     raise ValueError(
                         "Please don't set both cuda_graph_config.batch_sizes "
@@ -2129,7 +2140,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
                 config.max_batch_size = max(config.batch_sizes)
         else:
             max_batch_size = config.max_batch_size or 128
-            generated_sizes = self._generate_cuda_graph_batch_sizes(
+            generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes
             config.max_batch_size = max_batch_size
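
Net effect of the validator: if only `batch_sizes` is given, `max_batch_size` is derived from it; if only `max_batch_size` is given (or neither), `batch_sizes` is generated; if both are given, they must match the generated list exactly. A rough sketch of the default path, mirroring the validator's logic standalone and assuming `max_batch_size` defaults to 0 (unset):

    config = CudaGraphConfig(enable_padding=True)  # neither field set explicitly
    max_bs = config.max_batch_size or 128          # falls back to 128
    config.batch_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
        max_bs, config.enable_padding)
    config.max_batch_size = max_bs
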