Skip to content

Commit fe305bb

Browse files
committed
support azure openai MODEL_NAME for backward compatibility
1 parent 38fa9a7 commit fe305bb

File tree

4 files changed

+381
-55
lines changed

4 files changed

+381
-55
lines changed

src/aks-agent/azext_aks_agent/agent/llm_config_manager.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99

1010
import yaml
1111
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
12+
from azext_aks_agent.agent.llm_providers import PROVIDER_REGISTRY
1213
from azure.cli.core.api import get_config_dir
1314
from azure.cli.core.azclierror import AzCLIError
15+
from knack.log import get_logger
16+
17+
logger = get_logger(__name__)
1418

1519

1620
class LLMConfigManager:
@@ -67,18 +71,30 @@ def save(self, provider_name: str, params: dict):
6771
configs = {}
6872

6973
models = configs.get("llms", [])
70-
model_name = params.get("MODEL_NAME")
71-
if not model_name:
72-
raise ValueError("MODEL_NAME is required to save configuration.")
73-
74-
# Check if model already exists, update it and move it to the last;
75-
# otherwise, append new
76-
models = [
77-
cfg for cfg in models if not (
78-
cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name)]
79-
models.append({"provider": provider_name, **params})
8074

81-
configs["llms"] = models
75+
# modify existing azure openai config from model name to deployment name
76+
for model in models:
77+
if provider_name.lower() == "azure" and "MODEL_NAME" in model:
78+
model["DEPLOYMENT_NAME"] = model.pop("MODEL_NAME")
79+
80+
def _update_llm_config(provider_name, required_key, params, existing_models):
81+
required_value = params.get(required_key)
82+
if not required_value:
83+
raise ValueError(f"{required_key} is required to save configuration.")
84+
85+
# Check if model already exists, update it and move it to the last;
86+
# otherwise, append the new one.
87+
models = [
88+
cfg for cfg in existing_models if not (
89+
cfg.get("provider") == provider_name and cfg.get(required_key) == required_value)]
90+
models.append({"provider": provider_name, **params})
91+
return models
92+
93+
# To be consistent, we expose DEPLOYMENT_NAME for Azure provider in both configuration file and init prompts.
94+
if provider_name.lower() == "azure":
95+
configs["llms"] = _update_llm_config(provider_name, "DEPLOYMENT_NAME", params, models)
96+
else:
97+
configs["llms"] = _update_llm_config(provider_name, "MODEL_NAME", params, models)
8298

8399
with open(self.config_path, "w") as f:
84100
yaml.safe_dump(configs, f, sort_keys=False)
@@ -112,15 +128,17 @@ def get_specific(
112128
"""
113129
model_configs = self.get_list()
114130
for cfg in model_configs:
115-
if cfg.get("provider") == provider_name and cfg.get(
116-
"MODEL_NAME") == model_name:
117-
return cfg
131+
if cfg.get("provider") == provider_name:
132+
if provider_name.lower() == "azure" and (cfg.get("DEPLOYMENT_NAME") == model_name or cfg.get("MODEL_NAME") == model_name):
133+
return cfg
134+
elif cfg.get("MODEL_NAME") == model_name:
135+
return cfg
118136
return None
119137

120138
def get_model_config(self, model) -> Optional[Dict]:
121-
prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n" \
122-
"To configure your LLM manually, create a config file using the templates provided here: "\
123-
"https://aka.ms/aks/agentic-cli/init"
139+
prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n"
140+
"To configure your LLM manually, create a config file using the templates provided here: "
141+
"https://aka.ms/aks/agentic-cli/init"
124142

125143
if not model:
126144
llm_config: Optional[Dict] = self.get_latest()
@@ -147,3 +165,23 @@ def is_config_complete(self, config, provider_schema):
147165
config.get(key)):
148166
return False
149167
return True
168+
169+
def export_model_config(self, llm_config) -> str:
170+
# Check if the configuration is complete
171+
provider_name = llm_config.get("provider")
172+
provider_instance = PROVIDER_REGISTRY.get(provider_name)()
173+
# NOTE(mainred) for backward compatibility with Azure OpenAI, replace the MODEL_NAME with DEPLOYMENT_NAME
174+
if provider_name.lower() == "azure" and "MODEL_NAME" in llm_config:
175+
llm_config["DEPLOYMENT_NAME"] = llm_config.pop("MODEL_NAME")
176+
177+
model_name_key = "MODEL_NAME" if provider_name.lower() != "azure" else "DEPLOYMENT_NAME"
178+
model = provider_instance.model_name(llm_config.get(model_name_key))
179+
180+
# Set environment variables for the model provider
181+
for k, v in llm_config.items():
182+
if k not in ["provider", "MODEL_NAME", "DEPLOYMENT_NAME"]:
183+
os.environ[k] = v
184+
logger.info(
185+
"Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get("MODEL_NAME"))
186+
187+
return model

src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@ def validate_connection(self, params: dict) -> Tuple[bool, str, str]:
6262
api_key = params.get("AZURE_API_KEY")
6363
api_base = params.get("AZURE_API_BASE")
6464
api_version = params.get("AZURE_API_VERSION")
65-
model_name = params.get("MODEL_NAME")
65+
deployment_name = params.get("DEPLOYMENT_NAME")
6666

67-
if not all([api_key, api_base, api_version, model_name]):
67+
if not all([api_key, api_base, api_version, deployment_name]):
6868
return False, "Missing required Azure parameters.", "retry_input"
6969

7070
# REST API reference: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle?tabs=rest
71-
url = urljoin(api_base, f"openai/deployments/{model_name}/chat/completions")
71+
url = urljoin(api_base, f"openai/deployments/{deployment_name}/chat/completions")
7272

7373
query = {"api-version": api_version}
7474
full_url = f"{url}?{urlencode(query)}"
7575
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
7676
payload = {
77-
"model": model_name,
77+
"model": deployment_name,
7878
"messages": [{"role": "user", "content": "ping"}],
7979
"max_tokens": 16
8080
}

src/aks-agent/azext_aks_agent/custom.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import os
77

8-
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
8+
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME, HELP_COLOR
99

1010
# pylint: disable=too-many-lines, disable=broad-except
1111
from azext_aks_agent.agent.agent import aks_agent as aks_agent_internal
@@ -25,29 +25,28 @@
2525
# pylint: disable=unused-argument
2626
def aks_agent_init(cmd):
2727
"""Initialize AKS agent llm configuration."""
28+
from rich.console import Console
29+
console = Console()
30+
console.print(
31+
"Welcome to AKS Agent LLM configuration setup. Type '/exit' to exit.",
32+
style=f"bold {HELP_COLOR}")
2833

29-
with rich_logging() as console:
30-
from holmes.utils.colors import HELP_COLOR
31-
console.print(
32-
"Welcome to AKS Agent LLM configuration setup. Type '/exit' to exit.",
33-
style=f"bold {HELP_COLOR}")
34-
35-
provider = prompt_provider_choice()
36-
params = provider.prompt_params()
34+
provider = prompt_provider_choice()
35+
params = provider.prompt_params()
3736

38-
llm_config_manager = LLMConfigManager()
39-
# If the connection to the model endpoint is valid, save the configuration
40-
is_valid, msg, action = provider.validate_connection(params)
37+
llm_config_manager = LLMConfigManager()
38+
# If the connection to the model endpoint is valid, save the configuration
39+
is_valid, msg, action = provider.validate_connection(params)
4140

42-
if is_valid and action == "save":
43-
llm_config_manager.save(provider.model_route if provider.model_route else "openai", params)
44-
console.print(
45-
f"LLM configuration setup successfully and is saved to {llm_config_manager.config_path}.",
46-
style=f"bold {HELP_COLOR}")
47-
elif not is_valid and action == "retry_input":
48-
raise AzCLIError(f"Please re-run `az aks agent-init` to correct the input parameters. {str(msg)}")
49-
else:
50-
raise AzCLIError(f"Please check your deployed model and network connectivity. {str(msg)}")
41+
if is_valid and action == "save":
42+
llm_config_manager.save(provider.model_route if provider.model_route else "openai", params)
43+
console.print(
44+
f"LLM configuration setup successfully and is saved to {llm_config_manager.config_path}.",
45+
style=f"bold {HELP_COLOR}")
46+
elif not is_valid and action == "retry_input":
47+
raise AzCLIError(f"Please re-run `az aks agent-init` to correct the input parameters. {str(msg)}")
48+
else:
49+
raise AzCLIError(f"Please check your deployed model and network connectivity. {str(msg)}")
5150

5251

5352
# pylint: disable=unused-argument
@@ -81,18 +80,7 @@ def aks_agent(
8180
llm_config_manager = LLMConfigManager(config_file)
8281
llm_config_manager.validate_config()
8382
llm_config = llm_config_manager.get_model_config(model)
84-
85-
# Check if the configuration is complete
86-
provider_name = llm_config.get("provider")
87-
provider_instance = PROVIDER_REGISTRY.get(provider_name)()
88-
model = provider_instance.model_name(llm_config.get("MODEL_NAME"))
89-
90-
# Set environment variables for the model provider
91-
for k, v in llm_config.items():
92-
if k not in ["provider", "MODEL_NAME"]:
93-
os.environ[k] = v
94-
logger.info(
95-
"Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get("MODEL_NAME"))
83+
llm_config_manager.export_model_config(llm_config)
9684

9785
with rich_logging():
9886
aks_agent_internal(

0 commit comments

Comments
 (0)