@@ -74,6 +74,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
         # max_wait for 8 attempts = 2^(8-1) = 128 secs
         self.retry_attempts = model_config.get("retry_attempts", 8)
         self.generation_params: dict[Any, Any] = model_config.get("parameters") or {}
+        self.chat_template_params: dict[Any, Any] = model_config.get("chat_template_params") or {}
         self.hf_chat_template_model_id = model_config.get("hf_chat_template_model_id")
         if self.hf_chat_template_model_id:
             self.tokenizer = AutoTokenizer.from_pretrained(
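The new attribute lets a deployment toggle template-level behavior from config rather than code. A hypothetical config sketch of the intent (the model id and the enable_thinking key are illustrative assumptions, not defined by this change):

model_config = {
    "hf_chat_template_model_id": "Qwen/Qwen3-8B",  # assumed id, for illustration
    "retry_attempts": 8,
    "parameters": {"max_new_tokens": 512},
    # Forwarded verbatim to tokenizer.apply_chat_template(); which keys are
    # meaningful depends entirely on the model's chat template.
    "chat_template_params": {"enable_thinking": False},
}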
@@ -400,12 +401,15 @@ def ping(self) -> int:
         else:
             return self._ping_model(url=self.model_config.get("url"), auth_token=auth_token)

-    def get_chat_formatted_text(self, chat_format_object: Sequence[BaseMessage]) -> str:
+    def get_chat_formatted_text(
+        self, chat_format_object: Sequence[BaseMessage], **chat_template_params
+    ) -> str:
         chat_formatted_text = str(
             self.tokenizer.apply_chat_template(
                 utils.convert_messages_from_langchain_to_chat_format(chat_format_object),
                 tokenize=False,
                 add_generation_prompt=True,
+                **chat_template_params,
             )
         )
         logger.debug(f"Chat formatted text: {chat_formatted_text}")
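This pass-through works because transformers' apply_chat_template exposes extra keyword arguments to the chat template as Jinja variables. A minimal standalone sketch of that behavior (model id and flag chosen for illustration; the flag only has an effect if the template actually reads it):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
messages = [{"role": "user", "content": "Hello"}]

# Extra kwargs become template variables, so template-specific switches
# can be toggled per call without changing the formatting code.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
print(text)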
@@ -628,7 +632,11 @@ async def _generate_native_structured_output(
         json_schema = pydantic_model.model_json_schema()

         # Build Request
-        payload = {"inputs": self.get_chat_formatted_text(input.messages)}
+        payload = {
+            "inputs": self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
+        }

         # Prepare generation parameters with guidance
         generation_params_with_guidance = {
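Every call site below uses the same defensive pattern: **(self.chat_template_params or {}) unpacks to zero keyword arguments when nothing is configured (or if the attribute is ever reset to None), so existing configs keep their old behavior. A small sketch of the semantics (names are illustrative, not from the codebase):

def render(**template_kwargs):
    # Stand-in for apply_chat_template: report which kwargs arrive.
    return template_kwargs

params = None                       # unset or overwritten -> falsy
print(render(**(params or {})))     # {} -> no extra kwargs, old behavior

params = {"enable_thinking": False}
print(render(**(params or {})))     # {'enable_thinking': False}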
@@ -694,7 +702,11 @@ async def _generate_text(
         self._set_client(model_params.url, model_params.auth_token)
         client = cast(HttpClient, self._client)
         # Build Request
-        payload = {"inputs": self.get_chat_formatted_text(input.messages)}
+        payload = {
+            "inputs": self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
+        }
         payload = client.build_request_with_payload(payload=payload)
         # Send Request
         resp = await client.async_send_request(
@@ -851,7 +863,9 @@ async def _generate_native_structured_output(

         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = client.build_request(messages=input.messages)
@@ -919,7 +933,9 @@ async def _generate_text(
         # https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
         self._set_client(model_url, model_params.auth_token)
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -977,7 +993,9 @@ async def _generate_native_structured_output(

         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1048,7 +1066,9 @@ async def _generate_text(
         # set header to close connection otherwise spurious event loop errors show up - https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
         self._set_client(model_url, model_params.auth_token)
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1098,7 +1118,9 @@ async def _generate_native_structured_output(

         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1163,7 +1185,9 @@ async def _generate_text(
         try:
             self._set_client(model_url, model_params.auth_token)
             if self.model_config.get("completions_api", False):
-                formatted_prompt = self.get_chat_formatted_text(input.messages)
+                formatted_prompt = self.get_chat_formatted_text(
+                    input.messages, **(self.chat_template_params or {})
+                )
                 payload = self._client.build_request(formatted_prompt=formatted_prompt)
             else:
                 payload = self._client.build_request(messages=input.messages)