Commit 76eae4b
[Enhancement] Add support for Chat Template params (#52)
* Add support for Chat Template params
* Updated Documentation
* Updated Documentation
1 parent 0f08c6c · commit 76eae4b

2 files changed (+53 / -25 lines)


docs/getting_started/model_configuration.md

Lines changed: 20 additions & 16 deletions
@@ -44,22 +44,26 @@ SYGRA_MIXTRAL_8X7B_CHAT_TEMPLATE={% for m in messages %} ... {% endfor %}
 
 ### Configuration Properties
 
-| Key | Description |
-|-----|-------------|
-| `model_type` | Type of backend server (`tgi`, `vllm`, `openai`, `azure_openai`, `azure`, `mistralai`, `ollama`, `triton`) |
-| `model_name` | Model name for your deployments (for Azure/Azure OpenAI) |
-| `api_version` | API version for Azure or Azure OpenAI |
-| `hf_chat_template_model_id` | Hugging Face model ID |
-| `completions_api` | *(Optional)* Boolean: use completions API instead of chat completions API (default: false) |
-| `modify_tokenizer` | *(Optional)* Boolean: apply custom chat template and modify the base model tokenizer (default: false) |
-| `special_tokens` | *(Optional)* List of special stop tokens used in generation |
-| `post_process` | *(Optional)* Post processor after model inference (e.g. `models.model_postprocessor.RemoveThinkData`) |
-| `parameters` | *(Optional)* Generation parameters (see below) |
-| `ssl_verify` | *(Optional)* Verify SSL certificate (default: true) |
-| `ssl_cert` | *(Optional)* Path to SSL certificate file |
-> **Note:**
-> - Do **not** include `url`, `auth_token` in your YAML config. These are sourced from environment variables as described above.<br>
+
+| Key | Description |
+|-----|-------------|
+| `model_type` | Type of backend server (`tgi`, `vllm`, `openai`, `azure_openai`, `azure`, `mistralai`, `ollama`, `triton`) |
+| `model_name` | Model name for your deployments (for Azure/Azure OpenAI) |
+| `api_version` | API version for Azure or Azure OpenAI |
+| `hf_chat_template_model_id` | Hugging Face model ID |
+| `completions_api` | *(Optional)* Boolean: use completions API instead of chat completions API (default: false) |
+| `modify_tokenizer` | *(Optional)* Boolean: apply custom chat template and modify the base model tokenizer (default: false) |
+| `special_tokens` | *(Optional)* List of special stop tokens used in generation |
+| `post_process` | *(Optional)* Post processor after model inference (e.g. `models.model_postprocessor.RemoveThinkData`) |
+| `parameters` | *(Optional)* Generation parameters (see below) |
+| `chat_template_params` | *(Optional)* Chat template parameters (e.g. `reasoning_effort` for `gpt-oss-120b`); applies only when `completions_api` is enabled |
+| `ssl_verify` | *(Optional)* Verify SSL certificate (default: true) |
+| `ssl_cert` | *(Optional)* Path to SSL certificate file |
+
+![Note](https://img.shields.io/badge/Note-important-yellow)
+> - Do **not** include `url`, `auth_token`, or `api_key` in your YAML config. These are sourced from environment variables as described above.<br>
 > - If you want to set **ssl_verify** to **false** globally, you can set `ssl_verify:false` under `model_config` section in config/configuration.yaml
+
 #### Customizable Model Parameters
 
 - `temperature`: Sampling randomness (0.0–2.0; lower is more deterministic)
@@ -70,7 +74,7 @@ SYGRA_MIXTRAL_8X7B_CHAT_TEMPLATE={% for m in messages %} ... {% endfor %}
 - `presence_penalty`: (OpenAI only) Encourages novel tokens
 - `frequency_penalty`: (OpenAI only) Penalizes frequently occurring tokens
 
-The model alias set as a key in the configuration is referenced in your graph YAML files (for node types such as `llm` or `multi_llm`). You can override these model parameters in the graph YAML for specific scenarios.
+The model alias set as a key in the configuration is referenced in your graph YAML files (for node types such as `llm` or `multi_llm`). You can override these model `parameters` and `chat_template_params` in the graph YAML for specific scenarios.
 
 ---
 
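As a quick orientation before the code diff, here is a minimal sketch of a model configuration carrying the new `chat_template_params` key, written as the Python dict the model class ultimately receives (in the repo this lives in YAML; the alias, model ID, and every concrete value below are illustrative assumptions, not project defaults):

```python
# Illustrative only: mirrors the keys documented in the table above.
model_config = {
    "model_type": "vllm",                                  # backend server type
    "hf_chat_template_model_id": "openai/gpt-oss-120b",    # assumed HF model ID
    "completions_api": True,                               # chat_template_params only matter on this path
    "parameters": {"temperature": 0.7},                    # generation parameters
    "chat_template_params": {
        "reasoning_effort": "high",                        # forwarded to the chat template at render time
    },
}
```

Per the updated paragraph above, the same `chat_template_params` block can also be overridden per node in a graph YAML, just like `parameters`.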

sygra/core/models/custom_models.py

Lines changed: 33 additions & 9 deletions
@@ -74,6 +74,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
         # max_wait for 8 attempts = 2^(8-1) = 128 secs
         self.retry_attempts = model_config.get("retry_attempts", 8)
         self.generation_params: dict[Any, Any] = model_config.get("parameters") or {}
+        self.chat_template_params: dict[Any, Any] = model_config.get("chat_template_params") or {}
         self.hf_chat_template_model_id = model_config.get("hf_chat_template_model_id")
         if self.hf_chat_template_model_id:
             self.tokenizer = AutoTokenizer.from_pretrained(
@@ -400,12 +401,15 @@ def ping(self) -> int:
         else:
             return self._ping_model(url=self.model_config.get("url"), auth_token=auth_token)
 
-    def get_chat_formatted_text(self, chat_format_object: Sequence[BaseMessage]) -> str:
+    def get_chat_formatted_text(
+        self, chat_format_object: Sequence[BaseMessage], **chat_template_params
+    ) -> str:
         chat_formatted_text = str(
             self.tokenizer.apply_chat_template(
                 utils.convert_messages_from_langchain_to_chat_format(chat_format_object),
                 tokenize=False,
                 add_generation_prompt=True,
+                **chat_template_params,
             )
         )
         logger.debug(f"Chat formatted text: {chat_formatted_text}")
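The change above works because Hugging Face's `apply_chat_template` forwards extra keyword arguments to the Jinja chat template, so any variable the template reads (such as `reasoning_effort`) can be supplied from configuration. A standalone sketch of that behaviour, independent of the SyGra classes (the model ID and messages are assumptions for illustration):

```python
from transformers import AutoTokenizer

# Assumed model ID; any model whose chat template reads extra variables behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-120b")

messages = [{"role": "user", "content": "Explain chat templates in one sentence."}]

# Extra kwargs are passed through to the template renderer, mirroring the new
# **chat_template_params argument of get_chat_formatted_text().
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    reasoning_effort="high",
)
print(prompt)
```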
@@ -628,7 +632,11 @@ async def _generate_native_structured_output(
         json_schema = pydantic_model.model_json_schema()
 
         # Build Request
-        payload = {"inputs": self.get_chat_formatted_text(input.messages)}
+        payload = {
+            "inputs": self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
+        }
 
         # Prepare generation parameters with guidance
         generation_params_with_guidance = {
@@ -694,7 +702,11 @@ async def _generate_text(
         self._set_client(model_params.url, model_params.auth_token)
         client = cast(HttpClient, self._client)
         # Build Request
-        payload = {"inputs": self.get_chat_formatted_text(input.messages)}
+        payload = {
+            "inputs": self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
+        }
         payload = client.build_request_with_payload(payload=payload)
         # Send Request
         resp = await client.async_send_request(
@@ -851,7 +863,9 @@ async def _generate_native_structured_output(
 
         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = client.build_request(messages=input.messages)
@@ -919,7 +933,9 @@ async def _generate_text(
         # https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
         self._set_client(model_url, model_params.auth_token)
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -977,7 +993,9 @@ async def _generate_native_structured_output(
 
         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1048,7 +1066,9 @@ async def _generate_text(
         # set header to close connection otherwise spurious event loop errors show up - https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
         self._set_client(model_url, model_params.auth_token)
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1098,7 +1118,9 @@ async def _generate_native_structured_output(
 
         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1163,7 +1185,9 @@ async def _generate_text(
         try:
             self._set_client(model_url, model_params.auth_token)
             if self.model_config.get("completions_api", False):
-                formatted_prompt = self.get_chat_formatted_text(input.messages)
+                formatted_prompt = self.get_chat_formatted_text(
+                    input.messages, **(self.chat_template_params or {})
+                )
                 payload = self._client.build_request(formatted_prompt=formatted_prompt)
             else:
                 payload = self._client.build_request(messages=input.messages)
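Taken together, the repeated edits above follow one pattern: `chat_template_params` only flow through the prompt-rendering path, i.e. the branches guarded by `completions_api` and the TGI-style `"inputs"` payloads, while the chat-completions branch still sends the messages unchanged. A condensed sketch of that pattern (the helper name `build_payload` is hypothetical; in the actual file the logic appears inline in each `_generate_*` method):

```python
# Hypothetical condensation of the branch added throughout custom_models.py.
def build_payload(self, messages):
    if self.model_config.get("completions_api", False):
        # Render the messages into a single prompt string; extra template
        # variables (e.g. reasoning_effort) are applied here via chat_template_params.
        formatted_prompt = self.get_chat_formatted_text(
            messages, **(self.chat_template_params or {})
        )
        return self._client.build_request(formatted_prompt=formatted_prompt)
    # Chat-completions path: messages are sent as-is, so chat_template_params
    # are not applied on this branch.
    return self._client.build_request(messages=messages)
```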
