Skip to content

Commit 4ba36d5

Browse files
authored
{AKS}: agent-init replaces model name with deployment name for Azure OpenAI service (#9391)
1 parent 26d21d0 commit 4ba36d5

File tree

9 files changed

+393
-65
lines changed

9 files changed

+393
-65
lines changed

src/aks-agent/HISTORY.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ To release a new version, please select a new version number (usually plus 1 to
1212
Pending
1313
+++++++
1414

15+
1.0.0b9
16+
+++++++
17+
* agent-init: replace model name with deployment name for Azure OpenAI service.
18+
* agent-init: remove importing holmesgpt to resolve the latency issue.
19+
1520
1.0.0b8
1621
+++++++
1722
* Error handling: don't raise traceback for init prompt and holmesgpt interaction.

src/aks-agent/azext_aks_agent/_consts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,7 @@
2929
CONST_MCP_MIN_VERSION = "0.0.10"
3030
CONST_MCP_GITHUB_REPO = "Azure/aks-mcp"
3131
CONST_MCP_BINARY_DIR = "bin"
32+
33+
# Color constants for terminal output
34+
HELP_COLOR = "cyan" # same as AI_COLOR for now
35+
ERROR_COLOR = "red"

src/aks-agent/azext_aks_agent/agent/llm_config_manager.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99

1010
import yaml
1111
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
12+
from azext_aks_agent.agent.llm_providers import PROVIDER_REGISTRY
1213
from azure.cli.core.api import get_config_dir
1314
from azure.cli.core.azclierror import AzCLIError
15+
from knack.log import get_logger
16+
17+
logger = get_logger(__name__)
1418

1519

1620
class LLMConfigManager:
@@ -67,18 +71,30 @@ def save(self, provider_name: str, params: dict):
6771
configs = {}
6872

6973
models = configs.get("llms", [])
70-
model_name = params.get("MODEL_NAME")
71-
if not model_name:
72-
raise ValueError("MODEL_NAME is required to save configuration.")
73-
74-
# Check if model already exists, update it and move it to the last;
75-
# otherwise, append new
76-
models = [
77-
cfg for cfg in models if not (
78-
cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name)]
79-
models.append({"provider": provider_name, **params})
8074

81-
configs["llms"] = models
75+
# modify existing azure openai config from model name to deployment name
76+
for model in models:
77+
if provider_name.lower() == "azure" and "MODEL_NAME" in model:
78+
model["DEPLOYMENT_NAME"] = model.pop("MODEL_NAME")
79+
80+
def _update_llm_config(provider_name, required_key, params, existing_models):
81+
required_value = params.get(required_key)
82+
if not required_value:
83+
raise ValueError(f"{required_key} is required to save configuration.")
84+
85+
# Check if model already exists, update it and move it to the last;
86+
# otherwise, append the new one.
87+
models = [
88+
cfg for cfg in existing_models if not (
89+
cfg.get("provider") == provider_name and cfg.get(required_key) == required_value)]
90+
models.append({"provider": provider_name, **params})
91+
return models
92+
93+
# To be consistent, we expose DEPLOYMENT_NAME for Azure provider in both configuration file and init prompts.
94+
if provider_name.lower() == "azure":
95+
configs["llms"] = _update_llm_config(provider_name, "DEPLOYMENT_NAME", params, models)
96+
else:
97+
configs["llms"] = _update_llm_config(provider_name, "MODEL_NAME", params, models)
8298

8399
with open(self.config_path, "w") as f:
84100
yaml.safe_dump(configs, f, sort_keys=False)
@@ -112,14 +128,16 @@ def get_specific(
112128
"""
113129
model_configs = self.get_list()
114130
for cfg in model_configs:
115-
if cfg.get("provider") == provider_name and cfg.get(
116-
"MODEL_NAME") == model_name:
131+
if cfg.get("provider") == provider_name and provider_name.lower() == "azure":
132+
if cfg.get("DEPLOYMENT_NAME") == model_name or cfg.get("MODEL_NAME") == model_name:
133+
return cfg
134+
if cfg.get("provider") == provider_name and cfg.get("MODEL_NAME") == model_name:
117135
return cfg
118136
return None
119137

120138
def get_model_config(self, model) -> Optional[Dict]:
121139
prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n" \
122-
"To configure your LLM manually, create a config file using the templates provided here: "\
140+
"To configure your LLM manually, create a config file using the templates provided here: " \
123141
"https://aka.ms/aks/agentic-cli/init"
124142

125143
if not model:
@@ -147,3 +165,23 @@ def is_config_complete(self, config, provider_schema):
147165
config.get(key)):
148166
return False
149167
return True
168+
169+
def export_model_config(self, llm_config) -> str:
170+
# Check if the configuration is complete
171+
provider_name = llm_config.get("provider")
172+
provider_instance = PROVIDER_REGISTRY.get(provider_name)()
173+
# NOTE(mainred) for backward compatibility with Azure OpenAI, replace the MODEL_NAME with DEPLOYMENT_NAME
174+
if provider_name.lower() == "azure" and "MODEL_NAME" in llm_config:
175+
llm_config["DEPLOYMENT_NAME"] = llm_config.pop("MODEL_NAME")
176+
177+
model_name_key = "MODEL_NAME" if provider_name.lower() != "azure" else "DEPLOYMENT_NAME"
178+
model = provider_instance.model_name(llm_config.get(model_name_key))
179+
180+
# Set environment variables for the model provider
181+
for k, v in llm_config.items():
182+
if k not in ["provider", "MODEL_NAME", "DEPLOYMENT_NAME"]:
183+
os.environ[k] = v
184+
logger.info(
185+
"Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get("MODEL_NAME"))
186+
187+
return model

src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from typing import List, Tuple
77

8+
from azext_aks_agent._consts import ERROR_COLOR, HELP_COLOR
89
from rich.console import Console
910

1011
from .anthropic_provider import AnthropicProvider
@@ -47,7 +48,6 @@ def _get_provider_by_index(idx: int) -> LLMProvider:
4748
Return provider instance by numeric index (1-based).
4849
Raises ValueError if index is out of range.
4950
"""
50-
from holmes.utils.colors import HELP_COLOR
5151
if 1 <= idx <= len(_PROVIDER_CLASSES):
5252
console.print("You selected provider:", _PROVIDER_CLASSES[idx - 1]().readable_name, style=f"bold {HELP_COLOR}")
5353
return _PROVIDER_CLASSES[idx - 1]()
@@ -59,7 +59,6 @@ def prompt_provider_choice() -> LLMProvider:
5959
Show a numbered menu and return the chosen provider instance.
6060
Keeps prompting until a valid selection is made.
6161
"""
62-
from holmes.utils.colors import ERROR_COLOR, HELP_COLOR
6362
choices = _provider_choices_numbered()
6463
if not choices:
6564
raise ValueError("No providers are registered.")

src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def model_route(self) -> str:
3232
@property
3333
def parameter_schema(self):
3434
return {
35-
"MODEL_NAME": {
35+
"DEPLOYMENT_NAME": {
3636
"secret": False,
3737
"default": None,
38-
"hint": "should be consistent with your deployed name, e.g., gpt-5",
38+
"hint": "ensure your deployment name is the same as the model name, e.g., gpt-5",
3939
"validator": non_empty
4040
},
4141
"AZURE_API_KEY": {
@@ -62,19 +62,19 @@ def validate_connection(self, params: dict) -> Tuple[bool, str, str]:
6262
api_key = params.get("AZURE_API_KEY")
6363
api_base = params.get("AZURE_API_BASE")
6464
api_version = params.get("AZURE_API_VERSION")
65-
model_name = params.get("MODEL_NAME")
65+
deployment_name = params.get("DEPLOYMENT_NAME")
6666

67-
if not all([api_key, api_base, api_version, model_name]):
67+
if not all([api_key, api_base, api_version, deployment_name]):
6868
return False, "Missing required Azure parameters.", "retry_input"
6969

7070
# REST API reference: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle?tabs=rest
71-
url = urljoin(api_base, f"openai/deployments/{model_name}/chat/completions")
71+
url = urljoin(api_base, f"openai/deployments/{deployment_name}/chat/completions")
7272

7373
query = {"api-version": api_version}
7474
full_url = f"{url}?{urlencode(query)}"
7575
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
7676
payload = {
77-
"model": model_name,
77+
"model": deployment_name,
7878
"messages": [{"role": "user", "content": "ping"}],
7979
"max_tokens": 16
8080
}

src/aks-agent/azext_aks_agent/agent/llm_providers/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Callable, Dict, Tuple
99
from urllib.parse import urlparse
1010

11+
from azext_aks_agent._consts import ERROR_COLOR, HELP_COLOR
1112
from rich.console import Console
1213

1314
console = Console()
@@ -85,9 +86,6 @@ def parameter_schema(self) -> Dict[str, Dict[str, Any]]:
8586

8687
def prompt_params(self):
8788
"""Prompt user for parameters using parameter_schema when available."""
88-
from holmes.interactive import SlashCommands
89-
from holmes.utils.colors import ERROR_COLOR, HELP_COLOR
90-
9189
schema = self.parameter_schema
9290
params = {}
9391
for param, meta in schema.items():
@@ -134,7 +132,7 @@ def prompt_params(self):
134132
params[param] = value
135133
break
136134
console.print(
137-
f"Invalid value for {param}. Please try again, or type '{SlashCommands.EXIT.command}' to exit.",
135+
f"Invalid value for {param}. Please try again, or type '/exit' to exit.",
138136
style=f"{ERROR_COLOR}")
139137

140138
return params

src/aks-agent/azext_aks_agent/custom.py

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55

66
import os
77

8-
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
8+
from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME, HELP_COLOR
99

1010
# pylint: disable=too-many-lines, disable=broad-except
1111
from azext_aks_agent.agent.agent import aks_agent as aks_agent_internal
1212
from azext_aks_agent.agent.llm_config_manager import LLMConfigManager
13-
from azext_aks_agent.agent.llm_providers import (
14-
PROVIDER_REGISTRY,
15-
prompt_provider_choice,
16-
)
13+
from azext_aks_agent.agent.llm_providers import prompt_provider_choice
1714
from azext_aks_agent.agent.logging import rich_logging
1815
from azure.cli.core.api import get_config_dir
1916
from azure.cli.core.azclierror import AzCLIError
@@ -25,29 +22,28 @@
2522
# pylint: disable=unused-argument
2623
def aks_agent_init(cmd):
2724
"""Initialize AKS agent llm configuration."""
25+
from rich.console import Console
26+
console = Console()
27+
console.print(
28+
"Welcome to AKS Agent LLM configuration setup. Type '/exit' to exit.",
29+
style=f"bold {HELP_COLOR}")
2830

29-
with rich_logging() as console:
30-
from holmes.utils.colors import HELP_COLOR
31-
console.print(
32-
"Welcome to AKS Agent LLM configuration setup. Type '/exit' to exit.",
33-
style=f"bold {HELP_COLOR}")
34-
35-
provider = prompt_provider_choice()
36-
params = provider.prompt_params()
31+
provider = prompt_provider_choice()
32+
params = provider.prompt_params()
3733

38-
llm_config_manager = LLMConfigManager()
39-
# If the connection to the model endpoint is valid, save the configuration
40-
is_valid, msg, action = provider.validate_connection(params)
34+
llm_config_manager = LLMConfigManager()
35+
# If the connection to the model endpoint is valid, save the configuration
36+
is_valid, msg, action = provider.validate_connection(params)
4137

42-
if is_valid and action == "save":
43-
llm_config_manager.save(provider.model_route if provider.model_route else "openai", params)
44-
console.print(
45-
f"LLM configuration setup successfully and is saved to {llm_config_manager.config_path}.",
46-
style=f"bold {HELP_COLOR}")
47-
elif not is_valid and action == "retry_input":
48-
raise AzCLIError(f"Please re-run `az aks agent-init` to correct the input parameters. {str(msg)}")
49-
else:
50-
raise AzCLIError(f"Please check your deployed model and network connectivity. {str(msg)}")
38+
if is_valid and action == "save":
39+
llm_config_manager.save(provider.model_route if provider.model_route else "openai", params)
40+
console.print(
41+
f"LLM configuration setup successfully and is saved to {llm_config_manager.config_path}.",
42+
style=f"bold {HELP_COLOR}")
43+
elif not is_valid and action == "retry_input":
44+
raise AzCLIError(f"Please re-run `az aks agent-init` to correct the input parameters. {str(msg)}")
45+
else:
46+
raise AzCLIError(f"Please check your deployed model and network connectivity. {str(msg)}")
5147

5248

5349
# pylint: disable=unused-argument
@@ -81,18 +77,7 @@ def aks_agent(
8177
llm_config_manager = LLMConfigManager(config_file)
8278
llm_config_manager.validate_config()
8379
llm_config = llm_config_manager.get_model_config(model)
84-
85-
# Check if the configuration is complete
86-
provider_name = llm_config.get("provider")
87-
provider_instance = PROVIDER_REGISTRY.get(provider_name)()
88-
model = provider_instance.model_name(llm_config.get("MODEL_NAME"))
89-
90-
# Set environment variables for the model provider
91-
for k, v in llm_config.items():
92-
if k not in ["provider", "MODEL_NAME"]:
93-
os.environ[k] = v
94-
logger.info(
95-
"Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get("MODEL_NAME"))
80+
llm_config_manager.export_model_config(llm_config)
9681

9782
with rich_logging():
9883
aks_agent_internal(

0 commit comments

Comments
 (0)