99
1010import yaml
1111from azext_aks_agent ._consts import CONST_AGENT_CONFIG_FILE_NAME
12+ from azext_aks_agent .agent .llm_providers import PROVIDER_REGISTRY
1213from azure .cli .core .api import get_config_dir
1314from azure .cli .core .azclierror import AzCLIError
15+ from knack .log import get_logger
16+
17+ logger = get_logger (__name__ )
1418
1519
1620class LLMConfigManager :
@@ -67,18 +71,30 @@ def save(self, provider_name: str, params: dict):
6771 configs = {}
6872
6973 models = configs .get ("llms" , [])
70- model_name = params .get ("MODEL_NAME" )
71- if not model_name :
72- raise ValueError ("MODEL_NAME is required to save configuration." )
73-
74- # Check if model already exists, update it and move it to the last;
75- # otherwise, append new
76- models = [
77- cfg for cfg in models if not (
78- cfg .get ("provider" ) == provider_name and cfg .get ("MODEL_NAME" ) == model_name )]
79- models .append ({"provider" : provider_name , ** params })
8074
81- configs ["llms" ] = models
75+ # modify existing azure openai config from model name to deloyment name
76+ for model in models :
77+ if provider_name .lower () == "azure" and "MODEL_NAME" in model :
78+ model ["DEPLOYMENT_NAME" ] = model .pop ("MODEL_NAME" )
79+
80+ def _update_llm_config (provider_name , required_key , params , existing_models ):
81+ required_value = params .get (required_key )
82+ if not required_value :
83+ raise ValueError (f"{ required_key } is required to save configuration." )
84+
85+ # Check if model already exists, update it and move it to the last;
86+ # otherwise, append the new one.
87+ models = [
88+ cfg for cfg in existing_models if not (
89+ cfg .get ("provider" ) == provider_name and cfg .get (required_key ) == required_value )]
90+ models .append ({"provider" : provider_name , ** params })
91+ return models
92+
93+ # To be consistent, we expose DEPLOYMENT_NAME for Azure provider in both configuration file and init prompts.
94+ if provider_name .lower () == "azure" :
95+ configs ["llms" ] = _update_llm_config (provider_name , "DEPLOYMENT_NAME" , params , models )
96+ else :
97+ configs ["llms" ] = _update_llm_config (provider_name , "MODEL_NAME" , params , models )
8298
8399 with open (self .config_path , "w" ) as f :
84100 yaml .safe_dump (configs , f , sort_keys = False )
@@ -112,15 +128,17 @@ def get_specific(
112128 """
113129 model_configs = self .get_list ()
114130 for cfg in model_configs :
115- if cfg .get ("provider" ) == provider_name and cfg .get (
116- "MODEL_NAME" ) == model_name :
117- return cfg
131+ if cfg .get ("provider" ) == provider_name :
132+ if provider_name .lower () == "azure" and (cfg .get ("DEPLOYMENT_NAME" ) == model_name or cfg .get ("MODEL_NAME" ) == model_name ):
133+ return cfg
134+ elif cfg .get ("MODEL_NAME" ) == model_name :
135+ return cfg
118136 return None
119137
120138 def get_model_config (self , model ) -> Optional [Dict ]:
121- prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n " \
122- "To configure your LLM manually, create a config file using the templates provided here: " \
123- "https://aka.ms/aks/agentic-cli/init"
139+ prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n "
140+ "To configure your LLM manually, create a config file using the templates provided here: "
141+ "https://aka.ms/aks/agentic-cli/init"
124142
125143 if not model :
126144 llm_config : Optional [Dict ] = self .get_latest ()
@@ -147,3 +165,23 @@ def is_config_complete(self, config, provider_schema):
147165 config .get (key )):
148166 return False
149167 return True
168+
169+ def export_model_config (self , llm_config ) -> str :
170+ # Check if the configuration is complete
171+ provider_name = llm_config .get ("provider" )
172+ provider_instance = PROVIDER_REGISTRY .get (provider_name )()
173+ # NOTE(mainred) for backward compatibility with Azure OpenAI, replace the MODEL_NAME with DEPLOYMENT_NAME
174+ if provider_name .lower () == "azure" and "MODEL_NAME" in llm_config :
175+ llm_config ["DEPLOYMENT_NAME" ] = llm_config .pop ("MODEL_NAME" )
176+
177+ model_name_key = "MODEL_NAME" if provider_name .lower () != "azure" else "DEPLOYMENT_NAME"
178+ model = provider_instance .model_name (llm_config .get (model_name_key ))
179+
180+ # Set environment variables for the model provider
181+ for k , v in llm_config .items ():
182+ if k not in ["provider" , "MODEL_NAME" , "DEPLOYMENT_NAME" ]:
183+ os .environ [k ] = v
184+ logger .info (
185+ "Using provider: %s, model: %s, Env vars setup successfully." , provider_name , llm_config .get ("MODEL_NAME" ))
186+
187+ return model
0 commit comments