1616import json
1717import os
1818from contextlib import contextmanager
19- from typing import Dict , Iterable , List , Optional , Tuple , Union
19+ from typing import Any , Dict , Iterable , List , Optional , Tuple , Union
2020
2121import click
2222import numpy as np
@@ -51,11 +51,13 @@ class LmEvalWrapper(TemplateLM):
5151 def __init__ (self ,
5252 llm : Union [LLM , PyTorchLLM ],
5353 sampling_params : Optional [SamplingParams ] = None ,
54- streaming : bool = False ):
54+ streaming : bool = False ,
55+ chat_template_kwargs : Optional [dict [str , Any ]] = None ):
5556 super ().__init__ ()
5657 self .llm = llm
5758 self .sampling_params = sampling_params
5859 self .streaming = streaming
60+ self .chat_template_kwargs = chat_template_kwargs
5961
6062 @property
6163 def eot_token_id (self ) -> int :
@@ -72,6 +74,7 @@ def apply_chat_template(self,
7274 tokenize = False ,
7375 add_generation_prompt = add_generation_prompt ,
7476 continue_final_message = not add_generation_prompt ,
77+ ** (self .chat_template_kwargs or {}),
7578 )
7679
7780 @property
@@ -146,7 +149,8 @@ def __init__(self,
146149 llm : Union [LLM , PyTorchLLM ],
147150 sampling_params : Optional [SamplingParams ] = None ,
148151 streaming : bool = False ,
149- max_images : int = 999 ):
152+ max_images : int = 999 ,
153+ chat_template_kwargs : Optional [dict [str , Any ]] = None ):
150154 """
151155 Initialize the multimodal wrapper.
152156
@@ -161,6 +165,7 @@ def __init__(self,
161165 # NOTE: Required by lm_eval to identify this as a multimodal model
162166 self .MULTIMODAL = True
163167 self .max_images = max_images
168+ self .chat_template_kwargs = chat_template_kwargs
164169 self .model_type = self ._get_model_type (llm )
165170
166171 # NOTE: In TRT-LLM, currently we do not support interleaved text and image. Instead, we are adding image placeholders at the end of the text or at the beginning of the text.
@@ -237,7 +242,9 @@ def apply_chat_template(self,
237242 mm_placeholder_counts = mm_placeholder_counts ,
238243 tools = None ,
239244 chat_template_kwargs = {
240- "continue_final_message" : not add_generation_prompt
245+ ** (self .chat_template_kwargs or {}),
246+ "continue_final_message" :
247+ not add_generation_prompt ,
241248 })
242249 return output
243250
@@ -301,7 +308,8 @@ def __init__(self,
301308 apply_chat_template : bool = False ,
302309 fewshot_as_multiturn : bool = False ,
303310 system_prompt : Optional [str ] = None ,
304- is_multimodal : bool = False ):
311+ is_multimodal : bool = False ,
312+ chat_template_kwargs : Optional [dict [str , Any ]] = None ):
305313 try :
306314 import lm_eval
307315 except ImportError as e :
@@ -319,7 +327,8 @@ def __init__(self,
319327 super ().__init__ (random_seed = random_seed ,
320328 apply_chat_template = apply_chat_template ,
321329 fewshot_as_multiturn = fewshot_as_multiturn ,
322- system_prompt = system_prompt )
330+ system_prompt = system_prompt ,
331+ chat_template_kwargs = chat_template_kwargs )
323332 self .task_name = task_name
324333 self .dataset_path = dataset_path
325334 self .num_samples = num_samples
@@ -390,7 +399,10 @@ def evaluate(self,
390399 import lm_eval
391400 lm_cls = MultimodalLmEvalWrapper if self .MULTIMODAL else LmEvalWrapper
392401 results = lm_eval .evaluate (
393- lm = lm_cls (llm , sampling_params , streaming ),
402+ lm = lm_cls (llm ,
403+ sampling_params = sampling_params ,
404+ streaming = streaming ,
405+ chat_template_kwargs = self .chat_template_kwargs ),
394406 task_dict = self .task_dict ,
395407 limit = self .num_samples ,
396408 apply_chat_template = self .apply_chat_template ,
@@ -428,7 +440,9 @@ def command_harness(cls, ctx, **kwargs):
428440 fewshot_as_multiturn = kwargs .pop ("fewshot_as_multiturn" ,
429441 False ),
430442 system_prompt = kwargs .pop ("system_prompt" , None ),
431- is_multimodal = kwargs .pop ("is_multimodal" , False ))
443+ is_multimodal = kwargs .pop ("is_multimodal" , False ),
444+ chat_template_kwargs = kwargs .pop ("chat_template_kwargs" ,
445+ None ))
432446 sampling_params = SamplingParams (
433447 max_tokens = kwargs .pop ("max_output_length" ),
434448 truncate_prompt_tokens = kwargs .pop ("max_input_length" ),
@@ -462,6 +476,13 @@ def __init__(self, **kwargs):
462476 is_flag = True ,
463477 default = False ,
464478 help = "Whether to apply chat template." )
479+ @click .option (
480+ "--chat_template_kwargs" ,
481+ type = str ,
482+ default = None ,
483+ callback = lambda ctx , param , value : json .loads (value ) if value else None ,
484+ help =
485+ 'Chat template kwargs as JSON string, e.g., \' {"thinking_budget": 0}\' ' )
465486 @click .option ("--fewshot_as_multiturn" ,
466487 is_flag = True ,
467488 default = False ,
@@ -513,6 +534,13 @@ def __init__(self, **kwargs):
513534 is_flag = True ,
514535 default = False ,
515536 help = "Whether to apply chat template." )
537+ @click .option (
538+ "--chat_template_kwargs" ,
539+ type = str ,
540+ default = None ,
541+ callback = lambda ctx , param , value : json .loads (value ) if value else None ,
542+ help =
543+ 'Chat template kwargs as JSON string, e.g., \' {"thinking_budget": 0}\' ' )
516544 @click .option ("--system_prompt" ,
517545 type = str ,
518546 default = None ,
@@ -556,6 +584,13 @@ def __init__(self, **kwargs):
556584 is_flag = True ,
557585 default = False ,
558586 help = "Whether to apply chat template." )
587+ @click .option (
588+ "--chat_template_kwargs" ,
589+ type = str ,
590+ default = None ,
591+ callback = lambda ctx , param , value : json .loads (value ) if value else None ,
592+ help =
593+ 'Chat template kwargs as JSON string, e.g., \' {"thinking_budget": 0}\' ' )
559594 @click .option ("--system_prompt" ,
560595 type = str ,
561596 default = None ,
@@ -599,6 +634,13 @@ def __init__(self, **kwargs):
599634 is_flag = True ,
600635 default = False ,
601636 help = "Whether to apply chat template." )
637+ @click .option (
638+ "--chat_template_kwargs" ,
639+ type = str ,
640+ default = None ,
641+ callback = lambda ctx , param , value : json .loads (value ) if value else None ,
642+ help =
643+ 'Chat template kwargs as JSON string, e.g., \' {"thinking_budget": 0}\' ' )
602644 @click .option ("--system_prompt" ,
603645 type = str ,
604646 default = None ,
@@ -638,6 +680,13 @@ def __init__(self, **kwargs):
638680 type = int ,
639681 default = 0 ,
640682 help = "Random seed for dataset processing." )
683+ @click .option (
684+ "--chat_template_kwargs" ,
685+ type = str ,
686+ default = None ,
687+ callback = lambda ctx , param , value : json .loads (value ) if value else None ,
688+ help =
689+ 'Chat template kwargs as JSON string, e.g., \' {"thinking_budget": 0}\' ' )
641690 @click .option (
642691 "--system_prompt" ,
643692 type = str ,
0 commit comments