Skip to content

Commit d69bf9f

Browse files
authored
[None][feat] add chat template kwargs support to longbench-v2 (#9544)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 9d2df04 commit d69bf9f

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

tensorrt_llm/evaluate/longbench_v2.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def __init__(self,
6666
output_dir: Optional[str] = None,
6767
random_seed: int = 0,
6868
apply_chat_template: bool = False,
69-
system_prompt: Optional[str] = None):
69+
system_prompt: Optional[str] = None,
70+
chat_template_kwargs: Optional[dict[str, Any]] = None):
7071
"""Initialize LongBench v2 evaluator.
7172
7273
Args:
@@ -85,10 +86,12 @@ def __init__(self,
8586
random_seed: Random seed for reproducibility
8687
apply_chat_template: Whether to apply model's chat template
8788
system_prompt: System prompt to prepend
89+
chat_template_kwargs: Optional dict of extra kwargs forwarded to the model's chat template (already parsed from the CLI's JSON string)
8890
"""
8991
super().__init__(random_seed=random_seed,
9092
apply_chat_template=apply_chat_template,
91-
system_prompt=system_prompt)
93+
system_prompt=system_prompt,
94+
chat_template_kwargs=chat_template_kwargs)
9295

9396
self.dataset_path = dataset_path
9497
self.num_samples = num_samples
@@ -813,6 +816,15 @@ def _save_results(self, results: List[Dict], metrics: Dict[str, float]):
813816
type=int,
814817
default=32000,
815818
help="Maximum generation length in sampling parameters.")
819+
@click.option(
820+
"--chat_template_kwargs",
821+
type=str,
822+
default=None,
823+
callback=lambda ctx, param, value: json.loads(value) if value else None,
824+
help=
825+
'A JSON string specifying chat template arguments, used to enable features like thinking mode. Examples: '
826+
'\'{"enable_thinking": true}\' for Qwen3, or \'{"thinking": true}\' for DeepSeek-V3.2.'
827+
)
816828
@click.pass_context
817829
@staticmethod
818830
def command(ctx, dataset_path: str, prompts_dir: Optional[str],
@@ -821,7 +833,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
821833
cot: bool, no_context: bool, rag: int, max_len: int,
822834
output_dir: Optional[str], random_seed: int,
823835
apply_chat_template: bool, system_prompt: Optional[str],
824-
max_input_length: int, max_output_length: int) -> None:
836+
max_input_length: int, max_output_length: int,
837+
chat_template_kwargs: Optional[dict[str, Any]]) -> None:
825838
llm: Union[LLM, PyTorchLLM] = ctx.obj
826839

827840
sampling_params = SamplingParams(
@@ -844,7 +857,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
844857
output_dir=output_dir,
845858
random_seed=random_seed,
846859
apply_chat_template=apply_chat_template,
847-
system_prompt=system_prompt)
860+
system_prompt=system_prompt,
861+
chat_template_kwargs=chat_template_kwargs)
848862

849863
evaluator.evaluate(llm, sampling_params)
850864
llm.shutdown()

0 commit comments

Comments
 (0)