diff --git a/docs/getting_started/model_configuration.md b/docs/getting_started/model_configuration.md
index 06fa425e..f4df0ead 100644
--- a/docs/getting_started/model_configuration.md
+++ b/docs/getting_started/model_configuration.md
@@ -44,22 +44,26 @@ SYGRA_MIXTRAL_8X7B_CHAT_TEMPLATE={% for m in messages %} ... {% endfor %}
 
 ### Configuration Properties
 
-| Key                          | Description                                                                                                |
-|------------------------------|------------------------------------------------------------------------------------------------------------|
-| `model_type`                 | Type of backend server (`tgi`, `vllm`, `openai`, `azure_openai`, `azure`, `mistralai`, `ollama`, `triton`) |
-| `model_name`                 | Model name for your deployments (for Azure/Azure OpenAI)                                                   |
-| `api_version`                | API version for Azure or Azure OpenAI                                                                      |
-| `hf_chat_template_model_id`  | Hugging Face model ID                                                                                      |
-| `completions_api`            | *(Optional)* Boolean: use completions API instead of chat completions API (default: false)                 |
-| `modify_tokenizer`           | *(Optional)* Boolean: apply custom chat template and modify the base model tokenizer (default: false)      |
-| `special_tokens`             | *(Optional)* List of special stop tokens used in generation                                                |
-| `post_process`               | *(Optional)* Post processor after model inference (e.g. `models.model_postprocessor.RemoveThinkData`)      |
-| `parameters`                 | *(Optional)* Generation parameters (see below)                                                             |
-| `ssl_verify`                 | *(Optional)* Verify SSL certificate (default: true)                                                        |
-| `ssl_cert`                   | *(Optional)* Path to SSL certificate file                                                                  |
-> **Note:**
-> - Do **not** include `url`, `auth_token` in your YAML config. These are sourced from environment variables as described above.
+
+| Key                          | Description                                                                                                                |
+|------------------------------|----------------------------------------------------------------------------------------------------------------------------|
+| `model_type`                 | Type of backend server (`tgi`, `vllm`, `openai`, `azure_openai`, `azure`, `mistralai`, `ollama`, `triton`)                 |
+| `model_name`                 | Model name for your deployments (for Azure/Azure OpenAI)                                                                   |
+| `api_version`                | API version for Azure or Azure OpenAI                                                                                      |
+| `hf_chat_template_model_id`  | Hugging Face model ID                                                                                                      |
+| `completions_api`            | *(Optional)* Boolean: use completions API instead of chat completions API (default: false)                                 |
+| `modify_tokenizer`           | *(Optional)* Boolean: apply custom chat template and modify the base model tokenizer (default: false)                      |
+| `special_tokens`             | *(Optional)* List of special stop tokens used in generation                                                                |
+| `post_process`               | *(Optional)* Post processor after model inference (e.g. `models.model_postprocessor.RemoveThinkData`)                      |
+| `parameters`                 | *(Optional)* Generation parameters (see below)                                                                             |
+| `chat_template_params`       | *(Optional)* Extra chat template parameters applied when `completions_api` is enabled (e.g. `reasoning_effort` for `gpt-oss-120b`) |
+| `ssl_verify`                 | *(Optional)* Verify SSL certificate (default: true)                                                                        |
+| `ssl_cert`                   | *(Optional)* Path to SSL certificate file                                                                                  |
+
+![Note](https://img.shields.io/badge/Note-important-yellow)
+> - Do **not** include `url`, `auth_token`, or `api_key` in your YAML config. These are sourced from environment variables as described above.
 > - If you want to set **ssl_verify** to **false** globally, you can set `ssl_verify:false` under `model_config` section in config/configuration.yaml
+
 #### Customizable Model Parameters
 
 - `temperature`: Sampling randomness (0.0–2.0; lower is more deterministic)
@@ -70,7 +74,7 @@ SYGRA_MIXTRAL_8X7B_CHAT_TEMPLATE={% for m in messages %} ... {% endfor %}
 - `presence_penalty`: (OpenAI only) Encourages novel tokens
 - `frequency_penalty`: (OpenAI only) Penalizes frequently occurring tokens
 
-The model alias set as a key in the configuration is referenced in your graph YAML files (for node types such as `llm` or `multi_llm`). You can override these model parameters in the graph YAML for specific scenarios.
+The model alias set as a key in the configuration is referenced in your graph YAML files (for node types such as `llm` or `multi_llm`). You can override these model `parameters` and `chat_template_params` in the graph YAML for specific scenarios.
 
 ---
 
diff --git a/sygra/core/models/custom_models.py b/sygra/core/models/custom_models.py
index 3bf56005..bc22191b 100644
--- a/sygra/core/models/custom_models.py
+++ b/sygra/core/models/custom_models.py
@@ -74,6 +74,7 @@ def __init__(self, model_config: dict[str, Any]) -> None:
         # max_wait for 8 attempts = 2^(8-1) = 128 secs
         self.retry_attempts = model_config.get("retry_attempts", 8)
         self.generation_params: dict[Any, Any] = model_config.get("parameters") or {}
+        self.chat_template_params: dict[Any, Any] = model_config.get("chat_template_params") or {}
         self.hf_chat_template_model_id = model_config.get("hf_chat_template_model_id")
         if self.hf_chat_template_model_id:
             self.tokenizer = AutoTokenizer.from_pretrained(
@@ -400,12 +401,15 @@ def ping(self) -> int:
         else:
             return self._ping_model(url=self.model_config.get("url"), auth_token=auth_token)
 
-    def get_chat_formatted_text(self, chat_format_object: Sequence[BaseMessage]) -> str:
+    def get_chat_formatted_text(
+        self, chat_format_object: Sequence[BaseMessage], **chat_template_params
+    ) -> str:
         chat_formatted_text = str(
             self.tokenizer.apply_chat_template(
                 utils.convert_messages_from_langchain_to_chat_format(chat_format_object),
                 tokenize=False,
                 add_generation_prompt=True,
+                **chat_template_params,
             )
         )
         logger.debug(f"Chat formatted text: {chat_formatted_text}")
@@ -628,7 +632,11 @@ async def _generate_native_structured_output(
         json_schema = pydantic_model.model_json_schema()
 
         # Build Request
-        payload = {"inputs": self.get_chat_formatted_text(input.messages)}
+        payload = {
+            "inputs": self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
+        }
 
         # Prepare generation parameters with guidance
         generation_params_with_guidance = {
@@ -694,7 +702,11 @@ async def _generate_text(
         self._set_client(model_params.url, model_params.auth_token)
         client = cast(HttpClient, self._client)
         # Build Request
-        payload = {"inputs": self.get_chat_formatted_text(input.messages)}
+        payload = {
+            "inputs": self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
+        }
         payload = client.build_request_with_payload(payload=payload)
         # Send Request
         resp = await client.async_send_request(
@@ -851,7 +863,9 @@ async def _generate_native_structured_output(
 
         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = client.build_request(messages=input.messages)
@@ -919,7 +933,9 @@ async def _generate_text(
         # https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
         self._set_client(model_url, model_params.auth_token)
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -977,7 +993,9 @@ async def _generate_native_structured_output(
 
         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1048,7 +1066,9 @@ async def _generate_text(
         # set header to close connection otherwise spurious event loop errors show up - https://github.com/encode/httpx/discussions/2959#discussioncomment-7665278
         self._set_client(model_url, model_params.auth_token)
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1098,7 +1118,9 @@ async def _generate_native_structured_output(
 
         # Prepare payload using the client
         if self.model_config.get("completions_api", False):
-            formatted_prompt = self.get_chat_formatted_text(input.messages)
+            formatted_prompt = self.get_chat_formatted_text(
+                input.messages, **(self.chat_template_params or {})
+            )
             payload = self._client.build_request(formatted_prompt=formatted_prompt)
         else:
             payload = self._client.build_request(messages=input.messages)
@@ -1163,7 +1185,9 @@ async def _generate_text(
         try:
             self._set_client(model_url, model_params.auth_token)
             if self.model_config.get("completions_api", False):
-                formatted_prompt = self.get_chat_formatted_text(input.messages)
+                formatted_prompt = self.get_chat_formatted_text(
+                    input.messages, **(self.chat_template_params or {})
+                )
                 payload = self._client.build_request(formatted_prompt=formatted_prompt)
             else:
                 payload = self._client.build_request(messages=input.messages)
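For reference, a minimal sketch of how the new `chat_template_params` key could look in a model configuration. The alias, model ID, and parameter values below are illustrative and not taken from the repository's docs; `url` and `auth_token` still come from environment variables as the note above describes.

```yaml
# Hypothetical model alias -- adjust names and values to your own deployment.
gpt_oss_120b:
  model_type: vllm                      # backend server type
  hf_chat_template_model_id: openai/gpt-oss-120b
  completions_api: true                 # chat_template_params only apply on the completions path
  chat_template_params:
    reasoning_effort: high              # forwarded to the chat template renderer
  parameters:
    temperature: 0.7
    max_tokens: 2048
```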
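And a standalone sketch of what the updated `get_chat_formatted_text` does under the hood: recent versions of Hugging Face `transformers` let `apply_chat_template` take extra keyword arguments that are exposed to the Jinja chat template, which is how a value such as `reasoning_effort` can reach templates that read it. The model ID and messages here are placeholders.

```python
from transformers import AutoTokenizer

# Placeholder model ID; use any model whose chat template reads the extra variable.
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-120b")

messages = [{"role": "user", "content": "Summarize SyGra in one sentence."}]

# Extra kwargs are made available to the chat template, mirroring
# get_chat_formatted_text(input.messages, **self.chat_template_params).
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    reasoning_effort="high",
)
print(prompt)
```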