11import os
2- from typing import Any , Optional
2+ from typing import Any , Optional , List , Dict
33
4+ import toml
45from dotenv import load_dotenv
56from smolagents import ChatMessage , LiteLLMRouterModel , Tool
67
8+ import mxtoai .models as models
9+ import mxtoai .exceptions as exceptions
710from mxtoai ._logging import get_logger
811from mxtoai .models import ProcessingInstructions
912
@@ -25,64 +28,98 @@ def __init__(self, current_handle: Optional[ProcessingInstructions] = None, **kw
2528
2629 """
2730 self .current_handle = current_handle
31+ self .config_path = os .getenv ("LITELLM_CONFIG_PATH" , "model.config.example.toml" )
32+ self .config = self ._load_toml_config ()
2833
2934 # Configure model list from environment variables
30- model_list = [
31- {
32- "model_name" : "gpt-4" ,
33- "litellm_params" : {
34- "model" : f"azure/{ os .getenv ('GPT4O_1_NAME' )} " ,
35- "base_url" : os .getenv ("GPT4O_1_ENDPOINT" ),
36- "api_key" : os .getenv ("GPT4O_1_API_KEY" ),
37- "api_version" : os .getenv ("GPT4O_1_API_VERSION" ),
38- "weight" : int (os .getenv ("GPT4O_1_WEIGHT" , 5 )),
39- },
40- },
41- {
42- "model_name" : "gpt-4" ,
43- "litellm_params" : {
44- "model" : f"azure/{ os .getenv ('GPT41_MINI_NAME' )} " ,
45- "base_url" : os .getenv ("GPT41_MINI_ENDPOINT" ),
46- "api_key" : os .getenv ("GPT41_MINI_API_KEY" ),
47- "api_version" : os .getenv ("GPT41_MINI_API_VERSION" ),
48- "weight" : int (os .getenv ("GPT41_MINI_WEIGHT" , 5 )),
49- },
50- },
51- {
52- "model_name" : "gpt-4-reasoning" ,
53- "litellm_params" : {
54- "model" : f"azure/{ os .getenv ('O3_MINI_NAME' )} " ,
55- "api_base" : os .getenv ("O3_MINI_ENDPOINT" ),
56- "api_key" : os .getenv ("O3_MINI_API_KEY" ),
57- "api_version" : os .getenv ("O3_MINI_API_VERSION" ),
58- "weight" : int (os .getenv ("O3_MINI_WEIGHT" , 1 )),
59- },
60- },
61- ]
62-
63- client_router_kwargs = {
64- "routing_strategy" : "simple-shuffle" ,
65- "fallbacks" : [
66- {
67- "gpt-4" : ["gpt-4-reasoning" ] # Fallback to reasoning model if both GPT-4 instances fail
68- }
69- ],
70- # "set_verbose": True,
71- # "debug_level": "DEBUG",
72- "default_litellm_params" : {"drop_params" : True }, # Global setting for dropping unsupported parameters
73- }
74-
35+ model_list = self ._load_model_config ()
36+ client_router_kwargs = self ._load_router_config ()
37+
7538 # The model_id for LiteLLMRouterModel is the default model group the router will target.
7639 # Our _get_target_model() will override this per call via the 'model' param in generate().
77- default_model_group = "gpt-4"
40+ default_model_group = os .getenv ("LITELLM_DEFAULT_MODEL_GROUP" )
41+
42+ if not default_model_group :
43+ raise exceptions .EnvironmentVariableNotFoundException (
44+ "LITELLM_DEFAULT_MODEL_GROUP environment variable not found. Please set it to the default model group."
45+ )
7846
7947 super ().__init__ (
8048 model_id = default_model_group ,
81- model_list = model_list ,
82- client_kwargs = client_router_kwargs ,
49+ model_list = [ model . dict () for model in model_list ] ,
50+ client_kwargs = client_router_kwargs . dict () ,
8351 ** kwargs , # Pass through other LiteLLMModel/Model kwargs
8452 )
8553
def _load_toml_config(self) -> Dict[str, Any]:
    """
    Load the configuration from the TOML file at ``self.config_path``.

    Returns:
        Dict[str, Any]: Configuration parsed from the TOML file, or an empty
        dict when the file exists but cannot be read or parsed (the failure
        is logged).

    Raises:
        exceptions.ModelConfigFileNotFoundException: If nothing exists at
            ``self.config_path``.
    """
    if not os.path.exists(self.config_path):
        raise exceptions.ModelConfigFileNotFoundException(
            f"Model config file not found at {self.config_path}. Please check the path."
        )

    try:
        # TOML documents are UTF-8 by specification; do not depend on the
        # platform's default text encoding.
        with open(self.config_path, encoding="utf-8") as f:
            return toml.load(f)
    except Exception as e:
        # Kept broad on purpose: callers rely on getting an empty config back
        # (and a log entry) rather than an exception for unreadable files.
        logger.error(f"Failed to load TOML config: {e}")
        return {}
73+
def _load_model_config(self) -> "List[models.ModelConfig]":
    """
    Build the model list from the ``model`` entries of the loaded TOML config.

    Each entry's ``model_name`` and ``litellm_params`` are used to construct a
    ``models.ModelConfig`` wrapping a ``models.LiteLLMParams``.

    Returns:
        List[models.ModelConfig]: One configuration object per ``model`` entry.

    Raises:
        exceptions.ModelListNotFoundException: If the config contains no
            ``model`` entries.
    """
    model_entries = self.config.get("model") or []

    if isinstance(model_entries, dict):
        # In case there's only one model (TOML parser returns dict)
        model_entries = [model_entries]

    model_list = [
        models.ModelConfig(
            model_name=entry.get("model_name"),
            # A missing ``litellm_params`` table is passed on as "no params"
            # so the params model reports what is missing, instead of this
            # line failing with a TypeError from ``**None``.
            litellm_params=models.LiteLLMParams(**(entry.get("litellm_params") or {})),
        )
        for entry in model_entries
    ]

    if not model_list:
        raise exceptions.ModelListNotFoundException(
            "No model list found in config toml. Please check the configuration."
        )

    return model_list
103+
def _load_router_config(self) -> "models.RouterConfig":
    """
    Build the router configuration from the ``router_config`` table of the
    loaded TOML config.

    When the table is missing or empty, a warning is logged and a default
    configuration (simple-shuffle routing, no fallbacks, ``drop_params``
    enabled) is returned.

    Returns:
        models.RouterConfig: Router configuration
    """
    # Inspect the raw table before constructing the model: unpacking a missing
    # table (``**None``) would raise before the defaults could ever be applied.
    router_settings = self.config.get("router_config")

    if not router_settings:
        logger.warning("No router config found in model-config.toml. Using defaults.")
        return models.RouterConfig(
            routing_strategy="simple-shuffle",
            fallbacks=[],
            default_litellm_params={"drop_params": True},
        )

    return models.RouterConfig(**router_settings)
121+
122+
86123 def _get_target_model (self ) -> str :
87124 """
88125 Determine which model to route to based on the current handle configuration.
0 commit comments