36 changes: 20 additions & 16 deletions docs/getting_started/model_configuration.md
@@ -44,22 +44,26 @@ SYGRA_MIXTRAL_8X7B_CHAT_TEMPLATE={% for m in messages %} ... {% endfor %}

### Configuration Properties

| Key | Description |
|------------------------------|------------------------------------------------------------------------------------------------------------|
| `model_type` | Type of backend server (`tgi`, `vllm`, `openai`, `azure_openai`, `azure`, `mistralai`, `ollama`, `triton`) |
| `model_name` | Model name for your deployments (for Azure/Azure OpenAI) |
| `api_version` | API version for Azure or Azure OpenAI |
| `hf_chat_template_model_id` | Hugging Face model ID |
| `completions_api` | *(Optional)* Boolean: use completions API instead of chat completions API (default: false) |
| `modify_tokenizer` | *(Optional)* Boolean: apply custom chat template and modify the base model tokenizer (default: false) |
| `special_tokens` | *(Optional)* List of special stop tokens used in generation |
| `post_process` | *(Optional)* Post processor after model inference (e.g. `models.model_postprocessor.RemoveThinkData`) |
| `parameters` | *(Optional)* Generation parameters (see below) |
| `ssl_verify` | *(Optional)* Verify SSL certificate (default: true) |
| `ssl_cert` | *(Optional)* Path to SSL certificate file |
> **Note:**
> - Do **not** include `url`, `auth_token` in your YAML config. These are sourced from environment variables as described above.<br>

| Key | Description |
|------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| `model_type` | Type of backend server (`tgi`, `vllm`, `openai`, `azure_openai`, `azure`, `mistralai`, `ollama`, `triton`) |
| `model_name` | Model name for your deployments (for Azure/Azure OpenAI) |
| `api_version` | API version for Azure or Azure OpenAI |
| `hf_chat_template_model_id` | Hugging Face model ID |
| `completions_api` | *(Optional)* Boolean: use completions API instead of chat completions API (default: false) |
| `modify_tokenizer` | *(Optional)* Boolean: apply custom chat template and modify the base model tokenizer (default: false) |
| `special_tokens` | *(Optional)* List of special stop tokens used in generation |
| `post_process` | *(Optional)* Post processor after model inference (e.g. `models.model_postprocessor.RemoveThinkData`) |
| `parameters` | *(Optional)* Generation parameters (see below) |
| `chat_template_params`       | *(Optional)* Chat template parameters forwarded when `completions_api` is enabled (e.g. `reasoning_effort` for `gpt-oss-120b`)                              |
| `ssl_verify` | *(Optional)* Verify SSL certificate (default: true) |
| `ssl_cert` | *(Optional)* Path to SSL certificate file |

![Note](https://img.shields.io/badge/Note-important-yellow)
> - Do **not** include `url`, `auth_token`, or `api_key` in your YAML config. These are sourced from environment variables as described above.<br>
> - If you want to set **ssl_verify** to **false** globally, set `ssl_verify: false` under the `model_config` section in `config/configuration.yaml`
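
For illustration, a minimal model entry using these keys might look like the sketch below. The alias, model ID, and values are placeholders rather than required settings, and `url`/`auth_token` are deliberately absent because they come from environment variables.

```yaml
# Illustrative sketch only: the alias, model ID, and values are placeholders.
gpt_oss_120b:
  model_type: vllm
  hf_chat_template_model_id: openai/gpt-oss-120b
  completions_api: true            # format the prompt with the chat template before calling the completions API
  chat_template_params:
    reasoning_effort: high         # forwarded to the chat template when completions_api is enabled
  special_tokens:
    - "<|endoftext|>"              # placeholder stop token
  parameters:
    temperature: 0.7
  ssl_verify: true
```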

#### Customizable Model Parameters

- `temperature`: Sampling randomness (0.0–2.0; lower is more deterministic)
@@ -70,7 +74,7 @@ SYGRA_MIXTRAL_8X7B_CHAT_TEMPLATE={% for m in messages %} ... {% endfor %}
- `presence_penalty`: (OpenAI only) Encourages novel tokens
- `frequency_penalty`: (OpenAI only) Penalizes frequently occurring tokens

The model alias set as a key in the configuration is referenced in your graph YAML files (for node types such as `llm` or `multi_llm`). You can override these model parameters in the graph YAML for specific scenarios.
The model alias set as a key in the configuration is referenced in your graph YAML files (for node types such as `llm` or `multi_llm`). You can override these model `parameters` and `chat_template_params` in the graph YAML for specific scenarios.
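
A hedged sketch of such an override in a graph YAML follows. The exact node schema is not reproduced in this excerpt, so the node name, nesting, and values below are assumptions for illustration only.

```yaml
# Illustrative sketch only: node name, nesting, and values are assumptions.
nodes:
  answer_generator:
    node_type: llm
    model:
      name: gpt_oss_120b             # alias defined in the model configuration
      parameters:
        temperature: 0.2             # overrides the configured sampling temperature
      chat_template_params:
        reasoning_effort: low        # overrides the configured chat template parameter
```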

---

42 changes: 33 additions & 9 deletions sygra/core/models/custom_models.py
@@ -74,6 +74,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
# max_wait for 8 attempts = 2^(8-1) = 128 secs
self.retry_attempts = model_config.get("retry_attempts", 8)
self.generation_params: dict[Any, Any] = model_config.get("parameters") or {}
self.chat_template_params: dict[Any, Any] = model_config.get("chat_template_params") or {}
self.hf_chat_template_model_id = model_config.get("hf_chat_template_model_id")
if self.hf_chat_template_model_id:
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -400,12 +401,15 @@ def ping(self) -> int:
else:
return self._ping_model(url=self.model_config.get("url"), auth_token=auth_token)

def get_chat_formatted_text(self, chat_format_object: Sequence[BaseMessage]) -> str:
def get_chat_formatted_text(
self, chat_format_object: Sequence[BaseMessage], **chat_template_params
) -> str:
chat_formatted_text = str(
self.tokenizer.apply_chat_template(
utils.convert_messages_from_langchain_to_chat_format(chat_format_object),
tokenize=False,
add_generation_prompt=True,
**chat_template_params,
)
)
logger.debug(f"Chat formatted text: {chat_formatted_text}")
@@ -628,7 +632,11 @@ async def _generate_native_structured_output(
json_schema = pydantic_model.model_json_schema()

# Build Request
payload = {"inputs": self.get_chat_formatted_text(input.messages)}
payload = {
"inputs": self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
}

# Prepare generation parameters with guidance
generation_params_with_guidance = {
@@ -694,7 +702,11 @@ async def _generate_text(
self._set_client(model_params.url, model_params.auth_token)
client = cast(HttpClient, self._client)
# Build Request
payload = {"inputs": self.get_chat_formatted_text(input.messages)}
payload = {
"inputs": self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
}
payload = client.build_request_with_payload(payload=payload)
# Send Request
resp = await client.async_send_request(
@@ -851,7 +863,9 @@ async def _generate_native_structured_output(

# Prepare payload using the client
if self.model_config.get("completions_api", False):
formatted_prompt = self.get_chat_formatted_text(input.messages)
formatted_prompt = self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
payload = client.build_request(formatted_prompt=formatted_prompt)
else:
payload = client.build_request(messages=input.messages)
@@ -919,7 +933,9 @@ async def _generate_text(
# https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
self._set_client(model_url, model_params.auth_token)
if self.model_config.get("completions_api", False):
formatted_prompt = self.get_chat_formatted_text(input.messages)
formatted_prompt = self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
payload = self._client.build_request(formatted_prompt=formatted_prompt)
else:
payload = self._client.build_request(messages=input.messages)
@@ -977,7 +993,9 @@ async def _generate_native_structured_output(

# Prepare payload using the client
if self.model_config.get("completions_api", False):
formatted_prompt = self.get_chat_formatted_text(input.messages)
formatted_prompt = self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
payload = self._client.build_request(formatted_prompt=formatted_prompt)
else:
payload = self._client.build_request(messages=input.messages)
@@ -1048,7 +1066,9 @@ async def _generate_text(
# set header to close connection otherwise spurious event loop errors show up - https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
self._set_client(model_url, model_params.auth_token)
if self.model_config.get("completions_api", False):
formatted_prompt = self.get_chat_formatted_text(input.messages)
formatted_prompt = self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
payload = self._client.build_request(formatted_prompt=formatted_prompt)
else:
payload = self._client.build_request(messages=input.messages)
@@ -1098,7 +1118,9 @@ async def _generate_native_structured_output(

# Prepare payload using the client
if self.model_config.get("completions_api", False):
formatted_prompt = self.get_chat_formatted_text(input.messages)
formatted_prompt = self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
payload = self._client.build_request(formatted_prompt=formatted_prompt)
else:
payload = self._client.build_request(messages=input.messages)
@@ -1163,7 +1185,9 @@ async def _generate_text(
try:
self._set_client(model_url, model_params.auth_token)
if self.model_config.get("completions_api", False):
formatted_prompt = self.get_chat_formatted_text(input.messages)
formatted_prompt = self.get_chat_formatted_text(
input.messages, **(self.chat_template_params or {})
)
payload = self._client.build_request(formatted_prompt=formatted_prompt)
else:
payload = self._client.build_request(messages=input.messages)