Commit 8f5df2a (1 parent: b52a1db)

updated model configuration in routing. (#20)

* updated model configuration in routing; added the option to include model lists in .env and updated .env.example as well
* fallback json fix
* moved model config to toml as suggested
* added an env variable to override model config toml path
* added pydantic models for LLM model config and routing params
* updated setup info with model config in README
* minor visual fix to README
* fix tests looking for config file
* update yml to work with env changes

File tree: 10 files changed (+207, −60 lines)

.env.example (3 additions, 0 deletions)

```diff
@@ -1,3 +1,6 @@
+# override this to your own model config toml
+LITELLM_CONFIG_PATH=model.config.toml
+
 # Azure OpenAI Configuration (Legacy)
 AZURE_OPENAI_MODEL=your_model_name_here # e.g., o3-mini-deep-research
 AZURE_OPENAI_ENDPOINT=https://your-resource-name.cognitiveservices.azure.com/
```

.github/workflows/tests.yml (1 addition, 0 deletions)

```diff
@@ -22,6 +22,7 @@ jobs:
     env:
       X_API_KEY: "some-api-key"
       IS_PROD: "true"
+      LITELLM_DEFAULT_MODEL_GROUP: ${{ secrets.LITELLM_DEFAULT_MODEL_GROUP }}

     steps:
       - name: Checkout repository
```

.gitignore (2 additions, 0 deletions)

```diff
@@ -2,6 +2,8 @@ attachments/**

 results/**

+model.config.toml
+
 # Python
 __pycache__/
 *.py[cod]
```

README.md (31 additions, 6 deletions)

````diff
@@ -105,16 +105,12 @@ poetry run dramatiq mxtoai.tasks --watch ./.
 Copy the `.env.example` file to `.env` and update with your specific configuration:

 ```env
+LITELLM_CONFIG_PATH=model.config.toml
+
 # Redis configuration
 REDIS_HOST=localhost
 REDIS_PORT=6379

-# Model configuration
-MODEL_ENDPOINT=your_azure_openai_endpoint
-MODEL_API_KEY=your_azure_openai_api_key
-MODEL_NAME=your-azure-openai-model-deployment
-MODEL_API_VERSION=2025-01-01-preview
-
 # Optional for research functionality
 JINA_API_KEY=your-jina-api-key

@@ -126,6 +122,35 @@ AZURE_VISION_KEY=your-azure-vision-key
 SERPAPI_API_KEY=your-serpapi-api-key
 ```

+This project supports load balancing and routing across multiple models, so you can define as many models as you'd like. Copy `model.config.example.toml` to a new TOML file and update it with your preferred configuration, then set `LITELLM_CONFIG_PATH` in `.env` to the path of your TOML file relative to the project root.
+
+A sample configuration looks like this:
+
+```toml
+[[model]]
+model_name = "gpt-4"
+
+[model.litellm_params]
+model = "azure/gpt-4"
+base_url = "https://your-endpoint.openai.azure.com"
+api_key = "your-key"
+api_version = "2023-05-15"
+weight = 5
+```
+
+It is also recommended to set the router configuration; if unset, it defaults to the config below:
+
+```toml
+[router_config]
+routing_strategy = "simple-shuffle"
+
+[[router_config.fallbacks]]
+gpt-4 = ["gpt-4-reasoning"]
+
+[router_config.default_litellm_params]
+drop_params = true
+```
+
 ## API Endpoints

 ### Process Email
````
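The `simple-shuffle` routing strategy mentioned in the README distributes requests across deployments in proportion to their configured `weight`. A minimal stdlib sketch of weight-proportional selection (a hypothetical helper, not LiteLLM's actual implementation):

```python
import random

def pick_deployment(deployments: list[dict], rng: random.Random) -> dict:
    """Pick one deployment, with probability proportional to its configured weight."""
    weights = [d["litellm_params"].get("weight", 1) for d in deployments]
    return rng.choices(deployments, weights=weights, k=1)[0]

# Two "gpt-4" deployments with weight 5 and one reasoning model with weight 1,
# mirroring the shape of the sample config above.
deployments = [
    {"model_name": "gpt-4", "litellm_params": {"weight": 5}},
    {"model_name": "gpt-4", "litellm_params": {"weight": 5}},
    {"model_name": "gpt-4-reasoning", "litellm_params": {"weight": 1}},
]
rng = random.Random(0)
picks = [pick_deployment(deployments, rng)["model_name"] for _ in range(1000)]
```

With these weights, roughly 10 of every 11 requests should land on a `gpt-4` deployment.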

model.config.example.toml (new file, 38 additions, 0 deletions)

```diff
@@ -0,0 +1,38 @@
+[[model]]
+model_name = "gpt-4"
+
+[model.litellm_params]
+model = "azure/gpt-4"
+base_url = "https://your-endpoint.openai.azure.com"
+api_key = "your-key"
+api_version = "2023-05-15"
+weight = 5
+
+[[model]]
+model_name = "gpt-4"
+
+[model.litellm_params]
+model = "azure/gpt-4-1106-preview"
+base_url = "https://your-endpoint-2.openai.azure.com"
+api_key = "your-key-2"
+api_version = "2023-05-15"
+weight = 5
+
+[[model]]
+model_name = "gpt-4-reasoning"
+
+[model.litellm_params]
+model = "azure/gpt-4o-mini"
+base_url = "https://your-endpoint-3.openai.azure.com"
+api_key = "your-key-3"
+api_version = "2023-05-15"
+weight = 1
+
+[router_config]
+routing_strategy = "simple-shuffle"
+
+[[router_config.fallbacks]]
+gpt-4 = ["gpt-4-reasoning"]
+
+[router_config.default_litellm_params]
+drop_params = true
```

mxtoai/exceptions.py (14 additions, 2 deletions)

```diff
@@ -1,7 +1,19 @@
 class UnspportedHandleException(Exception):
-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__(message)

 class HandleAlreadyExistsException(Exception):
-    def __init__(self, message):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+class EnvironmentVariableNotFoundException(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+class ModelListNotFoundException(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+class ModelConfigFileNotFoundException(Exception):
+    def __init__(self, message: str):
         super().__init__(message)
```

mxtoai/models.py (18 additions, 2 deletions)

```diff
@@ -1,11 +1,11 @@
-from typing import Optional
+from typing import Optional, Any, Dict, List

 from pydantic import BaseModel


 class ProcessingInstructions(BaseModel):
     handle: str
-    aliases: list[str]
+    aliases: List[str]
     process_attachments: bool
     deep_research_mandatory: bool
     rejection_message: Optional[str] = (
@@ -18,3 +18,19 @@ class ProcessingInstructions(BaseModel):
     requires_schedule_extraction: bool = False
     target_model: Optional[str] = "gpt-4"
     output_instructions: Optional[str] = None
+
+class LiteLLMParams(BaseModel):
+    model: str
+    base_url: str
+    api_key: str
+    api_version: str
+    weight: int
+
+class ModelConfig(BaseModel):
+    model_name: str
+    litellm_params: LiteLLMParams
+
+class RouterConfig(BaseModel):
+    routing_strategy: str
+    fallbacks: List[Dict[str, List[str]]]
+    default_litellm_params: Dict[str, Any]
```
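The shape these pydantic models enforce can be illustrated without pulling in pydantic itself; a stdlib dataclass sketch of the same structure (no field validation, purely to show how a parsed `[[model]]` entry maps onto `ModelConfig`):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class LiteLLMParams:
    model: str
    base_url: str
    api_key: str
    api_version: str
    weight: int

@dataclass
class ModelConfig:
    model_name: str
    litellm_params: LiteLLMParams

@dataclass
class RouterConfig:
    routing_strategy: str
    fallbacks: list[dict[str, list[str]]]
    default_litellm_params: dict[str, Any]

# One parsed [[model]] entry, as the TOML loader would return it
entry = {
    "model_name": "gpt-4",
    "litellm_params": {
        "model": "azure/gpt-4",
        "base_url": "https://your-endpoint.openai.azure.com",
        "api_key": "your-key",
        "api_version": "2023-05-15",
        "weight": 5,
    },
}
cfg = ModelConfig(model_name=entry["model_name"],
                  litellm_params=LiteLLMParams(**entry["litellm_params"]))
```

The pydantic versions additionally coerce and validate field types at construction time, which the dataclass sketch does not.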

mxtoai/routed_litellm_model.py (86 additions, 49 deletions)

```diff
@@ -1,9 +1,12 @@
 import os
-from typing import Any, Optional
+from typing import Any, Optional, List, Dict

+import toml
 from dotenv import load_dotenv
 from smolagents import ChatMessage, LiteLLMRouterModel, Tool

+import mxtoai.models as models
+import mxtoai.exceptions as exceptions
 from mxtoai._logging import get_logger
 from mxtoai.models import ProcessingInstructions

@@ -25,64 +28,98 @@ def __init__(self, current_handle: Optional[ProcessingInstructions] = None, **kw
         """
         self.current_handle = current_handle
+        self.config_path = os.getenv("LITELLM_CONFIG_PATH", "model.config.example.toml")
+        self.config = self._load_toml_config()

         # Configure model list from environment variables
-        model_list = [
-            {
-                "model_name": "gpt-4",
-                "litellm_params": {
-                    "model": f"azure/{os.getenv('GPT4O_1_NAME')}",
-                    "base_url": os.getenv("GPT4O_1_ENDPOINT"),
-                    "api_key": os.getenv("GPT4O_1_API_KEY"),
-                    "api_version": os.getenv("GPT4O_1_API_VERSION"),
-                    "weight": int(os.getenv("GPT4O_1_WEIGHT", 5)),
-                },
-            },
-            {
-                "model_name": "gpt-4",
-                "litellm_params": {
-                    "model": f"azure/{os.getenv('GPT41_MINI_NAME')}",
-                    "base_url": os.getenv("GPT41_MINI_ENDPOINT"),
-                    "api_key": os.getenv("GPT41_MINI_API_KEY"),
-                    "api_version": os.getenv("GPT41_MINI_API_VERSION"),
-                    "weight": int(os.getenv("GPT41_MINI_WEIGHT", 5)),
-                },
-            },
-            {
-                "model_name": "gpt-4-reasoning",
-                "litellm_params": {
-                    "model": f"azure/{os.getenv('O3_MINI_NAME')}",
-                    "api_base": os.getenv("O3_MINI_ENDPOINT"),
-                    "api_key": os.getenv("O3_MINI_API_KEY"),
-                    "api_version": os.getenv("O3_MINI_API_VERSION"),
-                    "weight": int(os.getenv("O3_MINI_WEIGHT", 1)),
-                },
-            },
-        ]
-
-        client_router_kwargs = {
-            "routing_strategy": "simple-shuffle",
-            "fallbacks": [
-                {
-                    "gpt-4": ["gpt-4-reasoning"]  # Fallback to reasoning model if both GPT-4 instances fail
-                }
-            ],
-            # "set_verbose": True,
-            # "debug_level": "DEBUG",
-            "default_litellm_params": {"drop_params": True},  # Global setting for dropping unsupported parameters
-        }
-
+        model_list = self._load_model_config()
+        client_router_kwargs = self._load_router_config()
+
         # The model_id for LiteLLMRouterModel is the default model group the router will target.
         # Our _get_target_model() will override this per call via the 'model' param in generate().
-        default_model_group = "gpt-4"
+        default_model_group = os.getenv("LITELLM_DEFAULT_MODEL_GROUP")
+
+        if not default_model_group:
+            raise exceptions.EnvironmentVariableNotFoundException(
+                "LITELLM_DEFAULT_MODEL_GROUP environment variable not found. Please set it to the default model group."
+            )

         super().__init__(
             model_id=default_model_group,
-            model_list=model_list,
-            client_kwargs=client_router_kwargs,
+            model_list=[model.dict() for model in model_list],
+            client_kwargs=client_router_kwargs.dict(),
             **kwargs,  # Pass through other LiteLLMModel/Model kwargs
         )

+    def _load_toml_config(self) -> Dict[str, Any]:
+        """
+        Load configuration from a TOML file.
+
+        Returns:
+            Dict[str, Any]: Configuration loaded from the TOML file.
+        """
+
+        if not os.path.exists(self.config_path):
+            raise exceptions.ModelConfigFileNotFoundException(
+                f"Model config file not found at {self.config_path}. Please check the path."
+            )
+
+        try:
+            with open(self.config_path, "r") as f:
+                return toml.load(f)
+        except Exception as e:
+            logger.error(f"Failed to load TOML config: {e}")
+            return {}
+
+    def _load_model_config(self) -> List[Dict[str, Any]]:
+        """
+        Load model configuration from the TOML config.
+
+        Returns:
+            List[Dict[str, Any]]: List of model configurations.
+
+        """
+        model_entries = self.config.get("model", [])
+        model_list = []
+
+        if isinstance(model_entries, dict):
+            # In case there's only one model (TOML parser returns dict)
+            model_entries = [model_entries]
+
+        for entry in model_entries:
+            model_list.append(models.ModelConfig(
+                model_name=entry.get("model_name"),
+                litellm_params=models.LiteLLMParams(
+                    **entry.get("litellm_params")
+                )
+            ))
+
+        if not model_list:
+            raise exceptions.ModelListNotFoundException(
+                "No model list found in config toml. Please check the configuration."
+            )
+
+        return model_list
+
+    def _load_router_config(self) -> models.RouterConfig:
+        """
+        Load router configuration from the TOML config.
+
+        Returns:
+            models.RouterConfig: Router configuration
+        """
+        router_config = models.RouterConfig(**self.config.get("router_config"))
+
+        if not router_config:
+            logger.warning("No router config found in model-config.toml. Using defaults.")
+            return models.RouterConfig(
+                routing_strategy="simple-shuffle",
+                fallbacks=[],
+                default_litellm_params={"drop_params": True},
+            )
+        return router_config
+
+
     def _get_target_model(self) -> str:
         """
         Determine which model to route to based on the current handle configuration.
```
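Two behaviours in the config-loading code above are worth illustrating in isolation: a lone `[[model]]` entry can come back from the TOML parser as a dict rather than a list and must be normalised, and the `fallbacks` structure maps a model group to the groups tried next. A stdlib sketch with hypothetical helper names (not the class's actual methods):

```python
from typing import Any

def normalize_model_entries(entries: Any) -> list[dict]:
    """A single table can parse as a bare dict; wrap it so callers always see a list."""
    if isinstance(entries, dict):
        return [entries]
    return list(entries)

def fallback_groups(fallbacks: list[dict[str, list[str]]], group: str) -> list[str]:
    """Return the fallback model groups configured for `group`, or [] if none."""
    for mapping in fallbacks:
        if group in mapping:
            return mapping[group]
    return []

# A single model entry arriving as a dict is wrapped into a one-element list
entries = normalize_model_entries({"model_name": "gpt-4"})

# The config's [[router_config.fallbacks]] entry, as parsed: gpt-4 falls back to gpt-4-reasoning
groups = fallback_groups([{"gpt-4": ["gpt-4-reasoning"]}], "gpt-4")
```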

poetry.lock (13 additions, 1 deletion) — generated file, diff not rendered.

pyproject.toml (1 addition, 0 deletions)

```diff
@@ -44,6 +44,7 @@ dependencies = [
     "jinja2 (>=3.1.6,<4.0.0)",
     "pydantic[email] (>=2.11.4,<3.0.0)",
     "python-multipart (>=0.0.20,<0.0.21)",
+    "toml (>=0.10.2,<0.11.0)",
 ]

 [tool.ruff]
```
