Skip to content

Commit 616d1df

Browse files
authored
[None][chore] set the default value of max_num_tokens explicitly (#8208)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
1 parent 6a6124d commit 616d1df

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,6 @@ def create_py_executor(
248248
max_batch_size,
249249
) = llm_args.get_runtime_sizes()
250250

251-
if max_num_tokens is None:
252-
max_num_tokens = 8192
253-
254251
tokens_per_block = kv_cache_config.tokens_per_block
255252
if pytorch_backend_config.attn_backend == "VANILLA":
256253
tokens_per_block = max_num_tokens

tensorrt_llm/llmapi/llm_args.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,7 +1595,7 @@ class BaseLlmArgs(StrictBaseModel):
15951595
description="The maximum beam width.")
15961596

15971597
max_num_tokens: Optional[int] = Field(
1598-
default=None, description="The maximum number of tokens.")
1598+
default=8192, description="The maximum number of tokens.")
15991599

16001600
gather_generation_logits: bool = Field(
16011601
default=False,
@@ -1894,13 +1894,15 @@ def validate_build_config_with_runtime_params(self):
18941894

18951895
if self.max_batch_size is not None:
18961896
if self.max_batch_size > self.build_config.max_batch_size:
1897-
raise ValueError(
1898-
f"max_batch_size [{self.max_batch_size}] is greater than build_config.max_batch_size [{self.build_config.max_batch_size}] in build_config"
1897+
self.max_batch_size = self.build_config.max_batch_size
1898+
logger.warning(
1899+
f"max_batch_size [{self.max_batch_size}] is overridden by build_config.max_batch_size [{self.build_config.max_batch_size}] in build_config"
18991900
)
19001901
if self.max_num_tokens is not None:
19011902
if self.max_num_tokens > self.build_config.max_num_tokens:
1902-
raise ValueError(
1903-
f"max_num_tokens [{self.max_num_tokens}] is greater than build_config.max_num_tokens [{self.build_config.max_num_tokens}] in build_config"
1903+
self.max_num_tokens = self.build_config.max_num_tokens
1904+
logger.warning(
1905+
f"max_num_tokens [{self.max_num_tokens}] is overridden by build_config.max_num_tokens [{self.build_config.max_num_tokens}] in build_config"
19041906
)
19051907
if self.max_seq_len is not None:
19061908
if self.max_seq_len != self.build_config.max_seq_len:

tests/unittest/api_stability/references_committed/llm.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ methods:
7676
default: null
7777
max_num_tokens:
7878
annotation: Optional[int]
79-
default: null
79+
default: 8192
8080
# Misc
8181
load_format:
8282
annotation: Union[str, tensorrt_llm.llmapi.llm_args.LoadFormat]

0 commit comments

Comments (0)