@@ -632,7 +632,7 @@ def generate(
632632 temperature : float = 0.7 ,
633633 do_sample : bool = True ,
634634 top_p : float = 0.9 ,
635- skip_chat_template : bool = False ,
635+ use_chat_template : bool = False ,
636636 ** kwargs
637637 ) -> str :
638638 """Generate text from input, respecting the model's context length."""
@@ -648,10 +648,10 @@ def generate(
648648 safe_input_length , max_new_tokens = self ._get_safe_generation_params (max_length )
649649
650650 # Prefer chat-template tokenization when available to ensure special tokens are handled
651- # Skip chat template if skip_chat_template =True (treat chat models as completion models)
651+ # Apply chat template only when use_chat_template =True
652652 inputs = None
653653 prefers_chat_template = False
654- if not skip_chat_template :
654+ if use_chat_template :
655655 if self .is_chat_model is True :
656656 prefers_chat_template = True
657657 # Heuristic fallback if metadata wasn't provided
@@ -757,7 +757,7 @@ def generate(
757757 # Decode only the newly generated tokens
758758 new_tokens = outputs [0 ][input_length :]
759759 # Preserve special tokens in the decoded output when the chat template is NOT applied
759-    # (match try_chat_model_without_template.py behavior); skip them when it is applied
760- skip_special_tokens = not skip_chat_template
760+ skip_special_tokens = use_chat_template
761761 generated_text = self .tokenizer .decode (new_tokens , skip_special_tokens = skip_special_tokens )
762762
763763 return generated_text .strip ()
@@ -1035,9 +1035,9 @@ def __init__(
10351035 except Exception as e :
10361036 raise RuntimeError (f"Failed to initialize vLLM engine: { e } " )
10371037
1038- def _format_prompt (self , user_text : str , skip_chat_template : bool = False ) -> str :
1039- # Prefer chat template for chat models, unless skip_chat_template =True
1040- if skip_chat_template :
1038+ def _format_prompt (self , user_text : str , use_chat_template : bool = False ) -> str :
1039+ # Apply chat template for chat models only when use_chat_template =True
1040+ if not use_chat_template :
10411041 return user_text
10421042 try :
10431043 prefers_chat = False
@@ -1062,12 +1062,12 @@ def generate(
10621062 temperature : float = 0.7 ,
10631063 do_sample : bool = True ,
10641064 top_p : float = 0.9 ,
1065- skip_chat_template : bool = False ,
1065+ use_chat_template : bool = False ,
10661066 ** kwargs
10671067 ) -> str :
10681068 try :
10691069 from vllm import SamplingParams
1070- prompt = self ._format_prompt (input_text , skip_chat_template = skip_chat_template )
1070+ prompt = self ._format_prompt (input_text , use_chat_template = use_chat_template )
10711071 # Map our "max_length" contract to vLLM's max_tokens for new tokens
10721072 # Our safe length logic is in HF wrapper; here we approximate with max_tokens
10731073 params = SamplingParams (
@@ -1091,7 +1091,7 @@ def generate_batch(
10911091 temperature : float = 0.7 ,
10921092 do_sample : bool = True ,
10931093 top_p : float = 0.9 ,
1094- skip_chat_template : bool = False ,
1094+ use_chat_template : bool = False ,
10951095 ** kwargs
10961096 ) -> List [str ]:
10971097 """Generate for a list of prompts in one vLLM call.
@@ -1103,7 +1103,7 @@ def generate_batch(
11031103 return []
11041104 try :
11051105 from vllm import SamplingParams
1106- formatted = [self ._format_prompt (p , skip_chat_template = skip_chat_template ) for p in prompts ]
1106+ formatted = [self ._format_prompt (p , use_chat_template = use_chat_template ) for p in prompts ]
11071107 params = SamplingParams (
11081108 max_tokens = max_length ,
11091109 temperature = temperature ,
@@ -1122,7 +1122,7 @@ def generate_batch(
11221122 self .logger .error (f"vLLM batch generation failed: { e } " )
11231123 # Fall back to sequential to salvage outputs
11241124 return [
1125- self .generate (p , max_length = max_length , temperature = temperature , do_sample = do_sample , top_p = top_p , skip_chat_template = skip_chat_template , ** kwargs )
1125+ self .generate (p , max_length = max_length , temperature = temperature , do_sample = do_sample , top_p = top_p , use_chat_template = use_chat_template , ** kwargs )
11261126 for p in prompts
11271127 ]
11281128
@@ -1383,7 +1383,7 @@ def generate_batch(
13831383 completion_window = str (kwargs .pop ("batch_completion_window" , "24h" ))
13841384
13851385 # Wrapper-only kwargs should not be forwarded to the provider payload.
1386- kwargs .pop ("skip_chat_template" , None )
1386+ kwargs .pop ("use_chat_template" , None )
13871387
13881388 if not prefer_batch_api :
13891389 return super ().generate_batch (
@@ -1979,7 +1979,7 @@ def generate(
19791979 top_p : float = 0.9 ,
19801980 ** kwargs
19811981 ) -> str :
1982- kwargs .pop ("skip_chat_template" , None )
1982+ kwargs .pop ("use_chat_template" , None )
19831983 payload = {
19841984 "contents" : [{"role" : "user" , "parts" : [{"text" : input_text }]}],
19851985 "generationConfig" : self ._build_gemini_generation_config (
@@ -2022,7 +2022,7 @@ def generate_batch(
20222022 timeout = kwargs .pop ("batch_timeout_seconds" , self .batch_timeout_seconds )
20232023
20242024 # Wrapper-only kwarg
2025- kwargs .pop ("skip_chat_template" , None )
2025+ kwargs .pop ("use_chat_template" , None )
20262026
20272027 if not prefer_batch_api :
20282028 return super ().generate_batch (
0 commit comments