diff --git a/src/aks-agent/HISTORY.rst b/src/aks-agent/HISTORY.rst
index 07b6ddfc8ca..6acf6b4cc64 100644
--- a/src/aks-agent/HISTORY.rst
+++ b/src/aks-agent/HISTORY.rst
@@ -11,6 +11,12 @@ To release a new version, please select a new version number (usually plus 1 to
 Pending
 +++++++
+
+1.0.0b8
++++++++
+* Error handling: don't raise a traceback for the init prompt and holmesgpt interaction.
+* Improve the aks agent-init user experience.
+* Improve error handling for user interaction with holmesgpt.
 * Fix stdin reading hang in CI/CD pipelines by using select with timeout for non-interactive mode.
 * Update pytest marker registration and fix datetime.utcnow() deprecation warning in tests.
 * Improve test framework with real-time stderr output visibility and subprocess timeout.
diff --git a/src/aks-agent/README.rst b/src/aks-agent/README.rst
index 2d33953250a..7de9c76e391 100644
--- a/src/aks-agent/README.rst
+++ b/src/aks-agent/README.rst
@@ -37,6 +37,109 @@ For more details about supported model providers and required variables, see:
 https://docs.litellm.ai/docs/providers
 
+LLM Configuration Explained
+---------------------------
+
+The AKS Agent uses YAML configuration files to define LLM connections. Each configuration contains a provider specification and the required environment variables for that provider.
+
+Configuration Structure
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: yaml
+
+    llms:
+      - provider: azure
+        MODEL_NAME: gpt-4.1
+        AZURE_API_KEY: *******
+        AZURE_API_BASE: https://{azure-openai-service}.openai.azure.com/
+        AZURE_API_VERSION: 2025-04-01-preview
+
+Field Explanations
+^^^^^^^^^^^^^^^^^^
+
+**provider**
+  The LiteLLM provider route that determines which LLM service to use. This follows the LiteLLM provider specification from https://docs.litellm.ai/docs/providers.
+
+  Common values:
+
+  * ``azure`` - Azure OpenAI Service
+  * ``openai`` - OpenAI API
+  * ``anthropic`` - Anthropic Claude
+  * ``gemini`` - Google's Gemini
+  * ``openai_compatible`` - OpenAI-compatible APIs (e.g., local models, other services)
+
+**MODEL_NAME**
+  The specific model or deployment name to use. This varies by provider:
+
+  * For Azure OpenAI: Your deployment name (e.g., ``gpt-4.1``, ``gpt-35-turbo``)
+  * For OpenAI: Model name (e.g., ``gpt-4``, ``gpt-3.5-turbo``)
+  * For other providers: Check the specific model names in the LiteLLM documentation
+
+**Environment Variables by Provider**
+
+The remaining fields are environment variables required by each provider. These correspond to the authentication and configuration requirements of each LLM service:
+
+**Azure OpenAI (provider: azure)**
+  * ``AZURE_API_KEY`` - Your Azure OpenAI API key
+  * ``AZURE_API_BASE`` - Your Azure OpenAI endpoint URL (e.g., https://your-resource.openai.azure.com/)
+  * ``AZURE_API_VERSION`` - API version (e.g., 2024-02-01, 2025-04-01-preview)
+
+**OpenAI (provider: openai)**
+  * ``OPENAI_API_KEY`` - Your OpenAI API key (starts with sk-)
+
+**Gemini (provider: gemini)**
+  * ``GOOGLE_API_KEY`` - Your Google Cloud API key
+  * ``GOOGLE_API_ENDPOINT`` - Base URL for the Gemini API endpoint
+
+**Anthropic (provider: anthropic)**
+  * ``ANTHROPIC_API_KEY`` - Your Anthropic API key
+
+**OpenAI Compatible (provider: openai_compatible)**
+  * ``OPENAI_API_BASE`` - Base URL for the API endpoint
+  * ``OPENAI_API_KEY`` - API key (if required by the service)
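+
+For example, here is a complete single-model file for an OpenAI-compatible endpoint
+(such as a local Ollama server); the URL, key, and model name are illustrative placeholders:
+
+.. code-block:: yaml
+
+    llms:
+      - provider: openai_compatible
+        MODEL_NAME: llama3
+        OPENAI_API_BASE: http://localhost:11434/v1
+        OPENAI_API_KEY: not-needed-for-local-servers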
+
+Multiple Model Configuration
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can configure multiple models in a single file:
+
+.. code-block:: yaml
+
+    llms:
+      - provider: azure
+        MODEL_NAME: gpt-4
+        AZURE_API_KEY: your-azure-key
+        AZURE_API_BASE: https://your-azure-endpoint.openai.azure.com/
+        AZURE_API_VERSION: 2024-02-01
+      - provider: openai
+        MODEL_NAME: gpt-4
+        OPENAI_API_KEY: your-openai-key
+      - provider: anthropic
+        MODEL_NAME: claude-3-sonnet-20240229
+        ANTHROPIC_API_KEY: your-anthropic-key
+
+When using ``--model``, specify the provider and model as ``provider/model_name`` (e.g., ``azure/gpt-4``, ``openai/gpt-4``).
+
+Security Note
+^^^^^^^^^^^^^
+
+API keys and credentials in configuration files should be kept secure. Consider using:
+
+* Restricted file permissions (``chmod 600 config.yaml``)
+* Environment variable substitution where supported
+* Separate configuration files for different environments (dev/prod)
+
 Quick start and examples
 =========================
diff --git a/src/aks-agent/azext_aks_agent/__init__.py b/src/aks-agent/azext_aks_agent/__init__.py
index 213cda1be6f..c7fdcb8c3b2 100644
--- a/src/aks-agent/azext_aks_agent/__init__.py
+++ b/src/aks-agent/azext_aks_agent/__init__.py
@@ -4,10 +4,20 @@
 # --------------------------------------------------------------------------------------------
-from azure.cli.core import AzCommandsLoader
+import os
 # pylint: disable=unused-import
 import azext_aks_agent._help
+from azext_aks_agent._consts import (
+    CONST_AGENT_CONFIG_PATH_DIR_ENV_KEY,
+    CONST_AGENT_NAME,
+    CONST_AGENT_NAME_ENV_KEY,
+    CONST_DISABLE_PROMETHEUS_TOOLSET_ENV_KEY,
+    CONST_PRIVACY_NOTICE_BANNER,
+    CONST_PRIVACY_NOTICE_BANNER_ENV_KEY,
+)
+from azure.cli.core import AzCommandsLoader
+from azure.cli.core.api import get_config_dir
 
 class ContainerServiceCommandsLoader(AzCommandsLoader):
@@ -34,3 +44,14 @@ def load_arguments(self, command):
 
 COMMAND_LOADER_CLS = ContainerServiceCommandsLoader
+
+
+# NOTE(mainred): holmesgpt leverages the environment variables to customize its behavior.
+def customize_holmesgpt():
+    os.environ[CONST_DISABLE_PROMETHEUS_TOOLSET_ENV_KEY] = "true"
+    os.environ[CONST_AGENT_CONFIG_PATH_DIR_ENV_KEY] = get_config_dir()
+    os.environ[CONST_AGENT_NAME_ENV_KEY] = CONST_AGENT_NAME
+    os.environ[CONST_PRIVACY_NOTICE_BANNER_ENV_KEY] = CONST_PRIVACY_NOTICE_BANNER
+
+
+customize_holmesgpt()
diff --git a/src/aks-agent/azext_aks_agent/_help.py b/src/aks-agent/azext_aks_agent/_help.py
index becc69c98f4..b5324d00a96 100644
--- a/src/aks-agent/azext_aks_agent/_help.py
+++ b/src/aks-agent/azext_aks_agent/_help.py
@@ -8,7 +8,6 @@
 from knack.help_files import helps
-
 helps[
     "aks agent"
 ] = """
@@ -78,10 +77,11 @@
     Here is an example of config file:
     ```json
     llms:
-      - provider: "azure"
-        MODEL_NAME: "gpt-4.1"
-        AZURE_API_BASE: "https://"
-        AZURE_API_KEY: ""
+      - provider: azure
+        MODEL_NAME: gpt-4.1
+        AZURE_API_KEY: *******
+        AZURE_API_BASE: https://{azure-openai-service-name}.openai.azure.com/
+        AZURE_API_VERSION: 2025-04-01-preview
     # define a list of mcp servers, mcp server can be defined
     mcp_servers:
       aks_mcp:
diff --git a/src/aks-agent/azext_aks_agent/agent/agent.py b/src/aks-agent/azext_aks_agent/agent/agent.py
index fefdde74809..02a2bc4ed40 100644
--- a/src/aks-agent/azext_aks_agent/agent/agent.py
+++ b/src/aks-agent/azext_aks_agent/agent/agent.py
@@ -3,20 +3,13 @@
 # Licensed under the MIT License. See License.txt in the project root for license information.
# -------------------------------------------------------------------------------------------- -import logging import os import select import sys -from azext_aks_agent._consts import ( - CONST_AGENT_CONFIG_PATH_DIR_ENV_KEY, - CONST_AGENT_NAME, - CONST_AGENT_NAME_ENV_KEY, - CONST_DISABLE_PROMETHEUS_TOOLSET_ENV_KEY, - CONST_PRIVACY_NOTICE_BANNER, - CONST_PRIVACY_NOTICE_BANNER_ENV_KEY, -) +from azext_aks_agent.agent.logging import init_log from azure.cli.core.api import get_config_dir +from azure.cli.core.azclierror import CLIInternalError from azure.cli.core.commands.client_factory import get_subscription_id from knack.util import CLIError @@ -25,34 +18,6 @@ from .telemetry import CLITelemetryClient -# NOTE(mainred): environment variables to disable prometheus toolset loading should be set before importing holmes. -def customize_holmesgpt(): - os.environ[CONST_DISABLE_PROMETHEUS_TOOLSET_ENV_KEY] = "true" - os.environ[CONST_AGENT_CONFIG_PATH_DIR_ENV_KEY] = get_config_dir() - os.environ[CONST_AGENT_NAME_ENV_KEY] = CONST_AGENT_NAME - os.environ[CONST_PRIVACY_NOTICE_BANNER_ENV_KEY] = CONST_PRIVACY_NOTICE_BANNER - - -# NOTE(mainred): holmes leverage the log handler RichHandler to provide colorful, readable and well-formatted logs -# making the interactive mode more user-friendly. -# And we removed exising log handlers to avoid duplicate logs. -# Also make the console log consistent, we remove the telemetry and data logger to skip redundant logs. -def init_log(): - # NOTE(mainred): we need to disable INFO logs from LiteLLM before LiteLLM library is loaded, to avoid logging the - # debug logs from heading of LiteLLM. - logging.getLogger("LiteLLM").setLevel(logging.WARNING) - logging.getLogger("telemetry.main").setLevel(logging.WARNING) - logging.getLogger("telemetry.process").setLevel(logging.WARNING) - logging.getLogger("telemetry.save").setLevel(logging.WARNING) - logging.getLogger("telemetry.client").setLevel(logging.WARNING) - logging.getLogger("az_command_data_logger").setLevel(logging.WARNING) - - from holmes.utils.console.logging import init_logging - - # TODO: make log verbose configurable, currently disabled by []. - return init_logging([]) - - def _get_mode_state_file() -> str: """Get the path to the mode state file.""" config_dir = get_config_dir() @@ -168,8 +133,6 @@ def aks_agent( raise CLIError( "Please upgrade the python version to 3.10 or above to use aks agent." 
) - # customizing holmesgpt should called before importing holmes - customize_holmesgpt() # Initialize variables interactive = not no_interactive @@ -213,85 +176,88 @@ def aks_agent( # MCP Lifecycle Manager mcp_lifecycle = MCPLifecycleManager() - try: - config = None + config = None - if use_aks_mcp: - try: - config_params = { - 'config_file': config_file, - 'model': model, - 'api_key': api_key, - 'max_steps': max_steps, - 'verbose': show_tool_output - } - mcp_info = mcp_lifecycle.setup_mcp_sync(config_params) - config = mcp_info['config'] - - if show_tool_output: - from .user_feedback import ProgressReporter - ProgressReporter.show_status_message("MCP mode active - enhanced capabilities enabled", "info") - - except Exception as e: # pylint: disable=broad-exception-caught - # Fallback to traditional mode on any MCP setup failure - from .error_handler import AgentErrorHandler - mcp_error = AgentErrorHandler.handle_mcp_setup_error(e, "MCP initialization") - if show_tool_output: - console.print(f"[yellow]MCP setup failed, using traditional mode: {mcp_error.message}[/yellow]") - if mcp_error.suggestions: - console.print("[dim]Suggestions for next time:[/dim]") - for suggestion in mcp_error.suggestions[:3]: # Show only first 3 suggestions - console.print(f"[dim] • {suggestion}[/dim]") - use_aks_mcp = False - current_mode = "traditional" - - # Fallback to traditional mode if MCP setup failed or was disabled - if not config: - config = _setup_traditional_mode_sync(config_file, model, api_key, max_steps, show_tool_output) - if show_tool_output: - console.print("[yellow]Traditional mode active (MCP disabled)[/yellow]") + if use_aks_mcp: + try: + config_params = { + 'config_file': config_file, + 'model': model, + 'api_key': api_key, + 'max_steps': max_steps, + 'verbose': show_tool_output + } + mcp_info = mcp_lifecycle.setup_mcp_sync(config_params) + config = mcp_info['config'] - # Save the current mode to state file for next run - _save_current_mode(current_mode) + if show_tool_output: + from .user_feedback import ProgressReporter + ProgressReporter.show_status_message("MCP mode active - enhanced capabilities enabled", "info") - # Use smart refresh logic - effective_refresh_toolsets = smart_refresh + except Exception as e: # pylint: disable=broad-exception-caught + # Fallback to traditional mode on any MCP setup failure + from .error_handler import AgentErrorHandler + mcp_error = AgentErrorHandler.handle_mcp_setup_error(e, "MCP initialization") + if show_tool_output: + console.print(f"[yellow]MCP setup failed, using traditional mode: {mcp_error.message}[/yellow]") + if mcp_error.suggestions: + console.print("[dim]Suggestions for next time:[/dim]") + for suggestion in mcp_error.suggestions[:3]: # Show only first 3 suggestions + console.print(f"[dim] • {suggestion}[/dim]") + use_aks_mcp = False + current_mode = "traditional" + + # Fallback to traditional mode if MCP setup failed or was disabled + if not config: + config = _setup_traditional_mode_sync(config_file, model, api_key, max_steps, show_tool_output) if show_tool_output: - from .user_feedback import ProgressReporter - ProgressReporter.show_status_message( - f"Toolset refresh: {effective_refresh_toolsets} (Mode: {current_mode})", "info" - ) + console.print("[yellow]Traditional mode active (MCP disabled)[/yellow]") - # Create AI client once with proper refresh settings + # Save the current mode to state file for next run + _save_current_mode(current_mode) + + # Use smart refresh logic + effective_refresh_toolsets = smart_refresh + if 
show_tool_output: + from .user_feedback import ProgressReporter + ProgressReporter.show_status_message( + f"Toolset refresh: {effective_refresh_toolsets} (Mode: {current_mode})", "info" + ) + + # Validate inputs + if not prompt and not interactive and not piped_data: + raise CLIError( + "Either the 'prompt' argument must be provided (unless using --interactive mode)." + ) + try: + # prepare the toolsets ai = config.create_console_toolcalling_llm( dal=None, refresh_toolsets=effective_refresh_toolsets, ) + except Exception as e: + raise CLIError(f"Failed to create AI executor: {str(e)}") + + # Handle piped data + if piped_data: + if prompt: + # User provided both piped data and a prompt + prompt = f"Here's some piped output:\n\n{piped_data}\n\n{prompt}" + else: + # Only piped data, no prompt - ask what to do with it + prompt = f"Here's some piped output:\n\n{piped_data}\n\nWhat can you tell me about this output?" - # Validate inputs - if not prompt and not interactive and not piped_data: - raise CLIError( - "Either the 'prompt' argument must be provided (unless using --interactive mode)." - ) - - # Handle piped data - if piped_data: - if prompt: - # User provided both piped data and a prompt - prompt = f"Here's some piped output:\n\n{piped_data}\n\n{prompt}" - else: - # Only piped data, no prompt - ask what to do with it - prompt = f"Here's some piped output:\n\n{piped_data}\n\nWhat can you tell me about this output?" - - # Phase 2: Holmes Execution (synchronous - no event loop conflicts) - is_mcp_mode = current_mode == "mcp" + # Phase 2: Holmes Execution (synchronous - no event loop conflicts) + is_mcp_mode = current_mode == "mcp" + try: if interactive: _run_interactive_mode_sync(ai, cmd, resource_group_name, name, prompt, console, show_tool_output, is_mcp_mode, telemetry) else: _run_noninteractive_mode_sync(ai, config, cmd, resource_group_name, name, prompt, console, echo, show_tool_output, is_mcp_mode) - + except Exception as e: # pylint: disable=broad-exception-caught + raise CLIInternalError(f"Error occurred during execution: {str(e)}") finally: # Phase 3: MCP Cleanup (isolated async if needed) mcp_lifecycle.cleanup_mcp_sync() diff --git a/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py b/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py index 73ae82c4d94..42b8866ea26 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_config_manager.py @@ -5,11 +5,12 @@ import os -from typing import List, Dict, Optional -import yaml +from typing import Dict, List, Optional -from azure.cli.core.api import get_config_dir +import yaml from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME +from azure.cli.core.api import get_config_dir +from azure.cli.core.azclierror import AzCLIError class LLMConfigManager: @@ -21,6 +22,45 @@ def __init__(self, config_path=None): get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) self.config_path = os.path.expanduser(config_path) + def validate_config(self): + default_config_path = os.path.join(get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) + # suppose the default config is always valid since it's created by the CLI + if self.config_path == default_config_path: + return + + try: + with open(self.config_path, "r") as f: + config_data = yaml.safe_load(f) + + # Validate the configuration structure + if not isinstance(config_data, dict): + raise ValueError( + f"Configuration file {self.config_path} must contain a YAML dictionary/mapping.") + + if "llms" not in config_data: + raise 
ValueError( + f"Configuration file {self.config_path} must contain an 'llms' key.") + + if not isinstance(config_data["llms"], list): + raise ValueError( + f"Configuration file {self.config_path}: 'llms' must be a list.") + + if len(config_data["llms"]) == 0: + raise ValueError( + f"Configuration file {self.config_path}: 'llms' list cannot be empty.") + + for llm_config in config_data["llms"]: + if not isinstance(llm_config, dict): + raise ValueError( + f"Configuration file {self.config_path}: " + "each LLM configuration must be a dictionary/mapping.") + except FileNotFoundError: + raise ValueError(f"Configuration file {self.config_path} not found.") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML syntax in configuration file {self.config_path}: {e}") + except Exception as e: + raise ValueError(f"Failed to load configuration file {self.config_path}: {e}") + def save(self, provider_name: str, params: dict): configs = self.load() if not isinstance(configs, Dict): @@ -61,8 +101,7 @@ def get_latest(self) -> Optional[Dict]: model_configs = self.get_list() if model_configs: return model_configs[-1] - raise ValueError( - "No configurations found. Please run `az aks agent-init`") + return None def get_specific( self, @@ -76,9 +115,28 @@ def get_specific( if cfg.get("provider") == provider_name and cfg.get( "MODEL_NAME") == model_name: return cfg - raise ValueError( - f"No configuration found for provider '{provider_name}' with model '{model_name}'. " - f"Please run `az aks agent-init`") + return None + + def get_model_config(self, model) -> Optional[Dict]: + prompt_for_init = "Run 'az aks agent-init' to set up your LLM endpoint (recommended path).\n" \ + "To configure your LLM manually, create a config file using the templates provided here: "\ + "https://aka.ms/aks/agentic-cli/init" + + if not model: + llm_config: Optional[Dict] = self.get_latest() + if not llm_config: + raise AzCLIError(f"No LLM configurations found. {prompt_for_init}") + return llm_config + + provider_name = "openai" + model_name = model + if "/" in model: + provider_name, model_name = model.split("/", 1) + llm_config = self.get_specific(provider_name, model_name) + if not llm_config: + raise AzCLIError( + f"No configuration found for model '{model}'. 
{prompt_for_init}") + return llm_config def is_config_complete(self, config, provider_schema): """ diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py index 0efd01fbb48..12b379ef201 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/__init__.py @@ -4,14 +4,15 @@ # -------------------------------------------------------------------------------------------- from typing import List, Tuple + from rich.console import Console -from .base import LLMProvider -from .azure_provider import AzureProvider -from .openai_provider import OpenAIProvider + from .anthropic_provider import AnthropicProvider +from .azure_provider import AzureProvider +from .base import LLMProvider from .gemini_provider import GeminiProvider from .openai_compatible_provider import OpenAICompatibleProvider - +from .openai_provider import OpenAIProvider console = Console() @@ -26,19 +27,19 @@ PROVIDER_REGISTRY = {} for cls in _PROVIDER_CLASSES: - key = cls.name.lower() + key = cls().name.lower() if key not in PROVIDER_REGISTRY: PROVIDER_REGISTRY[key] = cls def _available_providers() -> List[str]: """Return a list of registered provider names (lowercase): ["azure", "openai", ...]""" - return list(PROVIDER_REGISTRY.keys()) + return _PROVIDER_CLASSES def _provider_choices_numbered() -> List[Tuple[int, str]]: """Return numbered choices: [(1, "azure"), (2, "openai"), ...].""" - return [(i + 1, name) for i, name in enumerate(_available_providers())] + return [(i + 1, provider().readable_name) for i, provider in enumerate(_available_providers())] def _get_provider_by_index(idx: int) -> LLMProvider: @@ -48,7 +49,7 @@ def _get_provider_by_index(idx: int) -> LLMProvider: """ from holmes.utils.colors import HELP_COLOR if 1 <= idx <= len(_PROVIDER_CLASSES): - console.print("You selected provider:", _PROVIDER_CLASSES[idx - 1].name, style=f"bold {HELP_COLOR}") + console.print("You selected provider:", _PROVIDER_CLASSES[idx - 1]().readable_name, style=f"bold {HELP_COLOR}") return _PROVIDER_CLASSES[idx - 1]() raise ValueError(f"Invalid provider index: {idx}") @@ -58,16 +59,17 @@ def prompt_provider_choice() -> LLMProvider: Show a numbered menu and return the chosen provider instance. Keeps prompting until a valid selection is made. """ - from holmes.utils.colors import HELP_COLOR, ERROR_COLOR - from holmes.interactive import SlashCommands + from holmes.utils.colors import ERROR_COLOR, HELP_COLOR choices = _provider_choices_numbered() if not choices: raise ValueError("No providers are registered.") while True: for idx, name in choices: console.print(f" {idx}. {name}", style=f"bold {HELP_COLOR}") + console.print(f" {len(choices) + 1}. For other providers, see https://aka.ms/aks/agentic-cli/init", + style=f"bold {HELP_COLOR}") sel_idx = console.input( - f"[bold {HELP_COLOR}]Enter the number of your LLM provider: [/bold {HELP_COLOR}]").strip().lower() + f"[bold {HELP_COLOR}]Please choose the LLM provider (1-{len(choices)}): [/bold {HELP_COLOR}]").strip() if sel_idx == "/exit": raise SystemExit(0) @@ -75,7 +77,7 @@ def prompt_provider_choice() -> LLMProvider: return _get_provider_by_index(int(sel_idx)) except ValueError as e: console.print( - f"{e}. Please enter a valid number, or type '{SlashCommands.EXIT.command}' to exit.", + f"{e}. 
Please enter a valid number, or type '/exit' to exit.", style=f"{ERROR_COLOR}") diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py index 6a911c2e4db..091889ddf06 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/anthropic_provider.py @@ -5,11 +5,18 @@ import requests + from .base import LLMProvider, non_empty class AnthropicProvider(LLMProvider): - name = "anthropic" + @property + def readable_name(self) -> str: + return "Anthropic" + + @property + def model_route(self) -> str: + return "anthropic" @property def parameter_schema(self): @@ -22,7 +29,7 @@ def parameter_schema(self): }, "MODEL_NAME": { "secret": False, - "default": "claude-3", + "default": "claude-sonnet-4", "hint": None, "validator": non_empty }, diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py index ce7e0510635..5d7e73658d8 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/azure_provider.py @@ -4,14 +4,30 @@ # -------------------------------------------------------------------------------------------- -import requests from typing import Tuple -from urllib.parse import urljoin, urlencode +from urllib.parse import urlencode, urljoin + +import requests + from .base import LLMProvider, is_valid_url, non_empty +def is_valid_api_base(v: str) -> bool: + # validate the v follows the pattern https://{azure-openai-service-name}.openai.azure.com/ + if not v.startswith("https://") or not v.endswith(".openai.azure.com/"): + return False + + return is_valid_url(v) + + class AzureProvider(LLMProvider): - name = "azure" + @property + def readable_name(self) -> str: + return "Azure OpenAI" + + @property + def model_route(self) -> str: + return "azure" @property def parameter_schema(self): @@ -19,7 +35,7 @@ def parameter_schema(self): "MODEL_NAME": { "secret": False, "default": None, - "hint": "should be consistent with your deployed name, e.g., gpt-4.1", + "hint": "should be consistent with your deployed name, e.g., gpt-5", "validator": non_empty }, "AZURE_API_KEY": { @@ -31,8 +47,8 @@ def parameter_schema(self): "AZURE_API_BASE": { "secret": False, "default": None, - "hint": "https://{your-custom-endpoint}.openai.azure.com/", - "validator": is_valid_url + "hint": "https://{azure-openai-service-name}.openai.azure.com/", + "validator": is_valid_api_base }, "AZURE_API_VERSION": { "secret": False, @@ -52,12 +68,16 @@ def validate_connection(self, params: dict) -> Tuple[bool, str, str]: return False, "Missing required Azure parameters.", "retry_input" # REST API reference: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle?tabs=rest - url = urljoin(api_base, "openai/responses") + url = urljoin(api_base, f"openai/deployments/{model_name}/chat/completions") + query = {"api-version": api_version} full_url = f"{url}?{urlencode(query)}" - headers = {"api-key": api_key, "Content-Type": "application/json"} - payload = {"model": model_name, - "input": "ping", "max_output_tokens": 16} + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + payload = { + "model": model_name, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 16 + } try: resp = requests.post(full_url, headers=headers, diff --git 
a/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py
index 1a59fd1824d..662af07c600 100644
--- a/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py
+++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/base.py
@@ -5,11 +5,11 @@
 
 from abc import ABC, abstractmethod
-from typing import Dict, Callable, Tuple, Any
-from rich.console import Console
-from rich.prompt import Prompt
+from typing import Any, Callable, Dict, Tuple
 from urllib.parse import urlparse
+from rich.console import Console
+
 console = Console()
 HINT_COLOR = "bright_black"
 DEFAULT_COLOR = "bright_black"
@@ -30,7 +30,41 @@ def is_valid_url(v: str) -> bool:
 
 class LLMProvider(ABC):
-    name = "base"
+
+    @property
+    @abstractmethod
+    def readable_name(self) -> str:
+        """Return the readable name for this provider.
+        The readable name is a human-readable string, e.g., "Azure OpenAI", "OpenAI", etc.
+        """
+        return "Base Provider"
+
+    @property
+    def name(self) -> str:
+        """Return the provider name for this provider.
+        The provider name is the key that identifies an LLM provider.
+        https://docs.litellm.ai/docs/providers
+        """
+        return self.model_route
+
+    @property
+    @abstractmethod
+    def model_route(self) -> str:
+        """Return the model route parameter key for this provider.
+        The model route is the model prefix for LLM providers supported by LiteLLM, for example "azure" for Azure OpenAI.
+        https://docs.litellm.ai/docs/providers
+        """
+        return "base"
+
+    def model_name(self, model_name) -> str:
+        """Return the model name for this provider.
+        The full model name combines the model route and the model name, e.g., "azure/gpt-5".
+        https://docs.litellm.ai/docs/providers
+        """
+        if self.model_route:
+            return f"{self.model_route}/{model_name}"
+
+        return model_name
 
     @property
     @abstractmethod
@@ -51,8 +85,8 @@ def parameter_schema(self) -> Dict[str, Dict[str, Any]]:
 
     def prompt_params(self):
         """Prompt user for parameters using parameter_schema when available."""
-        from holmes.utils.colors import HELP_COLOR, ERROR_COLOR
         from holmes.interactive import SlashCommands
+        from holmes.utils.colors import ERROR_COLOR, HELP_COLOR
 
         schema = self.parameter_schema
         params = {}
@@ -71,12 +105,26 @@
         while True:
             if secret:
-                value = Prompt.ask(
-                    f"[bold {HELP_COLOR}]Enter your API key[/]",
-                    password=True
-                )
+                # For password input, we'll handle the display differently
+                value = console.input(prompt, password=secret)
+                # Calculate the masked display value following the OpenAI pattern
+                if len(value) <= 8:
+                    # For short passwords, show all as asterisks
+                    display_value = '*' * len(value)
+                else:
+                    # Show first 3 chars + 3 dots + last 4 chars (OpenAI pattern)
+                    first_chars = value[:3]
+                    last_chars = value[-4:]
+                    display_value = f"{first_chars}...{last_chars}"
+                # rich renders the cursor-up escape as plain text rather than as a control sequence,
+                # so when we combine the cursor-up with the re-printed prompt, the console prints an extra "[1A".
+                # To avoid that, we use a workaround by printing the cursor up separately.
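+                # NOTE: "\033[1A" is the ANSI escape sequence that moves the cursor up one line; this
+                # assumes an ANSI-capable terminal (true for most modern terminals, including Windows Terminal).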
+ print("\033[1A", end='') + console.print(f"{prompt}{display_value}") else: - value = console.input(prompt) + value = console.input(prompt, password=False) if not value and default is not None: value = default diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py index 14b17eb2c19..461ad534b1c 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/gemini_provider.py @@ -5,11 +5,18 @@ import requests + from .base import LLMProvider, non_empty class GeminiProvider(LLMProvider): - name = "gemini" + @property + def readable_name(self) -> str: + return "Gemini" + + @property + def model_route(self) -> str: + return "gemini" @property def parameter_schema(self): diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py index 952431ab2e8..5c75f4817d1 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_compatible_provider.py @@ -3,13 +3,27 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import requests from urllib.parse import urljoin -from .base import LLMProvider, non_empty, is_valid_url + +import requests + +from .base import LLMProvider, is_valid_url, non_empty class OpenAICompatibleProvider(LLMProvider): - name = "openai_compatible" + @property + def readable_name(self) -> str: + return "OpenAI Compatible" + + @property + def name(self) -> str: + return "openai_compatible" + + @property + def model_route(self) -> str: + # LiteLLM uses "openai" as the provider to route the request to an OpenAI-compatible endpoint + # https://docs.litellm.ai/docs/providers/openai_compatible + return "openai" @property def parameter_schema(self): @@ -22,7 +36,7 @@ def parameter_schema(self): }, "API_KEY": { "secret": True, - "default": "ollama", + "default": None, "hint": None, "validator": non_empty }, diff --git a/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py index c495b0197b6..4ea4e9372bd 100644 --- a/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py +++ b/src/aks-agent/azext_aks_agent/agent/llm_providers/openai_provider.py @@ -5,19 +5,31 @@ import requests + from .base import LLMProvider, non_empty class OpenAIProvider(LLMProvider): - name = "openai" + @property + def readable_name(self) -> str: + return "OpenAI" + + @property + def name(self) -> str: + return "openai" + + @property + def model_route(self) -> str: + # Openai model route is empty under the Litellm provider scheme + return "" @property def parameter_schema(self): return { "MODEL_NAME": { "secret": False, - "default": None, - "hint": "gpt-4.1", + "default": "gpt-5", + "hint": None, "validator": non_empty }, "OPENAI_API_KEY": { diff --git a/src/aks-agent/azext_aks_agent/agent/logging.py b/src/aks-agent/azext_aks_agent/agent/logging.py new file mode 100644 index 00000000000..4d67a595ea9 --- /dev/null +++ b/src/aks-agent/azext_aks_agent/agent/logging.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for license information.
+# --------------------------------------------------------------------------------------------
+# pylint: disable=line-too-long
+
+import logging
+from contextlib import contextmanager
+
+
+# azure-cli-core handles and logs exceptions with metadata_logger, while the rich handler may cause duplicate logs.
+# ref: https://github.com/Azure/azure-cli/blob/37e3d6e857bfe05dbad6b9594d65589b8cfaee5a/src/azure-cli-core/azure/cli/core/azlogging.py#L207
+def mute_rich_logging():
+    """
+    Remove rich handlers from the root logger to prevent duplicate logging when raising errors.
+    """
+    root_logger = logging.getLogger()
+    rich_handlers = []
+
+    # Find and remove rich handlers
+    for handler in root_logger.handlers[:]:  # Create a copy to iterate safely
+        if (hasattr(handler, 'console') or
+                handler.__class__.__name__ == 'RichHandler' or
+                'rich' in handler.__class__.__module__.lower()):
+            rich_handlers.append(handler)
+            root_logger.removeHandler(handler)
+
+
+# NOTE(mainred): holmes leverages the log handler RichHandler to provide colorful, readable and well-formatted logs,
+# making the interactive mode more user-friendly.
+# We removed existing log handlers to avoid duplicate logs.
+# Also, to make the console log consistent, we remove the telemetry and data loggers to skip redundant logs.
+def init_log():
+    # NOTE(mainred): we need to disable INFO logs from LiteLLM before the LiteLLM library is loaded, to avoid
+    # logging the debug output LiteLLM prints at import time.
+    logging.getLogger("LiteLLM").setLevel(logging.WARNING)
+    logging.getLogger("telemetry.main").setLevel(logging.WARNING)
+    logging.getLogger("telemetry.process").setLevel(logging.WARNING)
+    logging.getLogger("telemetry.save").setLevel(logging.WARNING)
+    logging.getLogger("telemetry.client").setLevel(logging.WARNING)
+    logging.getLogger("az_command_data_logger").setLevel(logging.WARNING)
+
+    from holmes.utils.console.logging import init_logging
+
+    # TODO: make log verbose configurable, currently disabled by [].
+    return init_logging([])
+
+
+@contextmanager
+def rich_logging():
+    """
+    Context manager that initializes logging and automatically mutes rich logging on errors.
+    This combines initialization with automatic error handling.
+
+    Usage:
+        with rich_logging() as console:
+            # Rich logging is available
+            console.print("This will use rich logging")
+
+            # If any error is raised, rich logging will be muted automatically
+            raise CLIError("This won't be logged by rich handler")
+    """
+    # Initialize logging first
+    console = init_log()
+
+    try:
+        yield console
+    except Exception:
+        # When any exception occurs, mute rich logging to prevent duplicates
+        mute_rich_logging()
+        # Re-raise the exception so it can be handled normally by the CLI
+        raise
diff --git a/src/aks-agent/azext_aks_agent/custom.py b/src/aks-agent/azext_aks_agent/custom.py
index 01b28f65b61..5ffb428802f 100644
--- a/src/aks-agent/azext_aks_agent/custom.py
+++ b/src/aks-agent/azext_aks_agent/custom.py
@@ -3,21 +3,22 @@
 # Licensed under the MIT License. See License.txt in the project root for license information.
 # --------------------------------------------------------------------------------------------
-# pylint: disable=too-many-lines, disable=broad-except
 import os
-import sys
-from typing import Dict, Optional
-from azure.cli.core.api import get_config_dir
+
 from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME
+
+# pylint: disable=too-many-lines, broad-except
 from azext_aks_agent.agent.agent import aks_agent as aks_agent_internal
-from azext_aks_agent.agent.llm_providers import prompt_provider_choice, PROVIDER_REGISTRY
 from azext_aks_agent.agent.llm_config_manager import LLMConfigManager
-
-from azext_aks_agent.agent.agent import init_log
-
+from azext_aks_agent.agent.llm_providers import (
+    PROVIDER_REGISTRY,
+    prompt_provider_choice,
+)
+from azext_aks_agent.agent.logging import rich_logging
+from azure.cli.core.api import get_config_dir
+from azure.cli.core.azclierror import AzCLIError
 from knack.log import get_logger
-
 logger = get_logger(__name__)
@@ -25,40 +26,28 @@
 def aks_agent_init(cmd):
     """Initialize AKS agent llm configuration."""
-    init_log()
-
-    from rich.console import Console
-    from holmes.utils.colors import HELP_COLOR, ERROR_COLOR
-    from holmes.interactive import SlashCommands
-
-    console = Console()
-    console.print(
-        f"Welcome to AKS Agent LLM configuration setup. Type '{SlashCommands.EXIT.command}' to exit.",
-        style=f"bold {HELP_COLOR}")
-
-    provider = prompt_provider_choice()
-    params = provider.prompt_params()
-
-    llm_config_manager = LLMConfigManager()
-    # If the connection to the model endpoint is valid, save the configuration
-    is_valid, message, action = provider.validate_connection(params)
-
-    if is_valid and action == "save":
-        logger.info("%s", message)
-        llm_config_manager.save(provider.name, params)
-        console.print("LLM configuration setup successfully.", style=f"bold {HELP_COLOR}")
-
-    elif not is_valid and action == "retry_input":
-        logger.warning("%s", message)
-        console.print(
-            "Please re-run [bold]`az aks agent-init`[/bold] to correct the input parameters.", style=f"{ERROR_COLOR}")
-        sys.exit(1)
-
-    else:
-        logger.error("%s", message)
+    with rich_logging() as console:
+        from holmes.utils.colors import HELP_COLOR
         console.print(
-            "Please check your deployed model and network connectivity.", style=f"bold {ERROR_COLOR}")
-        sys.exit(1)
+            "Welcome to AKS Agent LLM configuration setup. Type '/exit' to exit.",
+            style=f"bold {HELP_COLOR}")
+
+        provider = prompt_provider_choice()
+        params = provider.prompt_params()
+
+        llm_config_manager = LLMConfigManager()
+        # If the connection to the model endpoint is valid, save the configuration
+        is_valid, msg, action = provider.validate_connection(params)
+
+        if is_valid and action == "save":
+            llm_config_manager.save(provider.model_route if provider.model_route else "openai", params)
+            console.print(
+                f"LLM configuration set up successfully and saved to {llm_config_manager.config_path}.",
+                style=f"bold {HELP_COLOR}")
+        elif not is_valid and action == "retry_input":
+            raise AzCLIError(f"Please re-run `az aks agent-init` to correct the input parameters. {str(msg)}")
+        else:
+            raise AzCLIError(f"Please check your deployed model and network connectivity. 
{str(msg)}") # pylint: disable=unused-argument @@ -83,112 +72,44 @@ def aks_agent( if status: return aks_agent_status(cmd) - llm_config_manager = LLMConfigManager() - llm_config = None - default_llm_config_path = os.path.join( - get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) - - if config_file == default_llm_config_path: - if not model: - logger.info("Using default configuration file: %s", config_file) - llm_config: Optional[Dict] = llm_config_manager.get_latest() - if not llm_config: - raise ValueError( - "No llm configurations found. " - "Please run `az aks agent init` " - "or provide a config file using --config-file.") - - else: - logger.info("Using specified model: %s", model) - # parsing model into provider/model - if "/" in model: - provider_name, model_name = model.split("/", 1) - else: - provider_name = "openai" - model_name = model - llm_config = llm_config_manager.get_specific( - provider_name, model_name) - + if not config_file: + config_file = os.path.join( + get_config_dir(), CONST_AGENT_CONFIG_FILE_NAME) + logger.info("Using default configuration file: %s", config_file) else: - if config_file: - logger.info("Using user configuration file: %s", config_file) - import yaml - try: - with open(config_file, "r") as f: - llm_config = yaml.safe_load(f)["llms"][0] - if not isinstance(llm_config, Dict): - raise ValueError( - "Configuration file format is invalid. It should be a YAML mapping.") - except Exception as e: - raise ValueError(f"Failed to load configuration file: {e}") - - else: - raise ValueError( - "No configuration found. " - "Please run `az aks agent-init` or provide a config file using --config-file, " - "or specify a model using --model.") + logger.info("Using user configuration file: %s", config_file) + llm_config_manager = LLMConfigManager(config_file) + llm_config_manager.validate_config() + llm_config = llm_config_manager.get_model_config(model) # Check if the configuration is complete provider_name = llm_config.get("provider") provider_instance = PROVIDER_REGISTRY.get(provider_name)() - parameter_schema = provider_instance.parameter_schema - if _check_provider( - provider_name, - parameter_schema, - llm_config, - llm_config_manager): - # get model for holmesgpt/litellm: provider_name/model_name - model_name = llm_config.get("MODEL_NAME") - if provider_name == "openai": - model = model or model_name - elif provider_name == "openai_compatiable": - model = model or f"openai/{model_name}" - else: - model = model or f"{provider_name}/{model_name}" - # Set environment variables for the model provider - for k, v in llm_config.items(): - if k not in ["provider", "MODEL_NAME"]: - os.environ[k] = v - logger.info( - "Using provider: %s, model: %s, Env vars setup successfully.", provider_name, model_name) - - aks_agent_internal( - cmd, - resource_group_name, - name, - prompt, - model, - api_key, - max_steps, - config_file, - no_interactive, - no_echo_request, - show_tool_output, - refresh_toolsets, - use_aks_mcp=use_aks_mcp, - ) - - -def _check_provider( - provider_name: str, - parameter_schema: Dict, - llm_config: Dict, - llm_config_manager: LLMConfigManager -) -> bool: - # Check if provider name is not empty - if not provider_name: - raise ValueError("No provider name.") - # Check if provider is supported - if provider_name not in PROVIDER_REGISTRY: - supported = list(PROVIDER_REGISTRY.keys()) - raise ValueError( - f"Unsupported provider {provider_name} for LLM initialization." - f"Supported llm providers are {supported}. 
Please refer to doc.") - # check if provider config is complete - if not llm_config_manager.is_config_complete(llm_config, parameter_schema): - raise ValueError( - "Incomplete configuration in user config, please run `az aks agent-init` to initialize.") - return True + model = provider_instance.model_name(llm_config.get("MODEL_NAME")) + + # Set environment variables for the model provider + for k, v in llm_config.items(): + if k not in ["provider", "MODEL_NAME"]: + os.environ[k] = v + logger.info( + "Using provider: %s, model: %s, Env vars setup successfully.", provider_name, llm_config.get("MODEL_NAME")) + + with rich_logging(): + aks_agent_internal( + cmd, + resource_group_name, + name, + prompt, + model, + api_key, + max_steps, + config_file, + no_interactive, + no_echo_request, + show_tool_output, + refresh_toolsets, + use_aks_mcp=use_aks_mcp, + ) def aks_agent_status(cmd): @@ -199,9 +120,9 @@ def aks_agent_status(cmd): :return: None (displays status via console output) """ try: + from azext_aks_agent._consts import CONST_MCP_BINARY_DIR from azext_aks_agent.agent.binary_manager import AksMcpBinaryManager from azext_aks_agent.agent.mcp_manager import MCPManager - from azext_aks_agent._consts import CONST_MCP_BINARY_DIR # Initialize status information status_info = { @@ -278,8 +199,8 @@ def aks_agent_status(cmd): def _display_agent_status(status_info): """Display formatted status with rich console output.""" from rich.console import Console - from rich.table import Table from rich.panel import Panel + from rich.table import Table console = Console() diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent.py index bde484b0ca2..70d12195ba9 100644 --- a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent.py +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent.py @@ -8,7 +8,8 @@ import unittest from unittest.mock import Mock, patch -from azext_aks_agent.agent.agent import aks_agent, init_log +from azext_aks_agent.agent.agent import aks_agent +from azext_aks_agent.agent.logging import init_log from azure.cli.core.util import CLIError # Mock the holmes modules before any imports that might trigger holmes imports @@ -33,28 +34,6 @@ def setUpModule(): raise unittest.SkipTest("Tests in this module require Python >= 3.10") -class TestInitLog(unittest.TestCase): - """Test cases for init_log function""" - - @patch('azext_aks_agent.agent.agent.logging.getLogger') - def test_init_log_logger_level_setting(self, mock_get_logger): - """Test that specific loggers get WARNING level set""" - # Arrange - mock_logger = Mock() - mock_get_logger.return_value = mock_logger - - with patch('holmes.utils.console.logging.init_logging') as mock_init_logging: - mock_init_logging.return_value = Mock() - - # Act - init_log() - - # Assert that setLevel was called 6 times with WARNING - self.assertEqual(mock_logger.setLevel.call_count, 6) - for call_args in mock_logger.setLevel.call_args_list: - self.assertEqual(call_args[0][0], logging.WARNING) - - class TestAksAgent(unittest.TestCase): """Test cases for aks_agent function""" diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py index 360e8361fd9..eca9dfc2ef5 100644 --- a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_init.py @@ -5,9 +5,10 @@ import types import unittest -from unittest.mock import patch, 
MagicMock -from azext_aks_agent.custom import aks_agent_init +from unittest.mock import MagicMock, patch +from azext_aks_agent.custom import aks_agent_init +from azure.cli.core.azclierror import AzCLIError mock_logging = MagicMock(name="init_logging") mock_console_mod = types.SimpleNamespace(logging=types.SimpleNamespace(init_logging=mock_logging)) @@ -73,6 +74,7 @@ def test_init_successful_save( mock_provider.prompt_params.return_value = {'MODEL_NAME': 'test-model', 'param': 'value'} mock_provider.validate_connection.return_value = (True, 'Valid', 'save') mock_provider.name = 'openai' + mock_provider.model_route = 'openai' mock_prompt_provider_choice.return_value = mock_provider mock_config_manager = MagicMock() @@ -84,7 +86,6 @@ def test_init_successful_save( aks_agent_init(cmd=None) mock_config_manager.save.assert_called_once_with('openai', {'MODEL_NAME': 'test-model', 'param': 'value'}) - mock_console.print.assert_any_call("LLM configuration setup successfully.", style=unittest.mock.ANY) @patch('holmes.interactive.SlashCommands') @patch('holmes.utils.colors.ERROR_COLOR') @@ -116,13 +117,10 @@ def test_init_retry_input( mock_error_color.__str__.return_value = "red" mock_slash_commands.EXIT.command = "exit" - with self.assertRaises(SystemExit) as cm: + with self.assertRaises(AzCLIError) as cm: aks_agent_init(cmd=None) - self.assertEqual(cm.exception.code, 1) - mock_console.print.assert_any_call( - "Please re-run [bold]`az aks agent-init`[/bold] to correct the input parameters.", - style=unittest.mock.ANY, - ) + self.assertEqual(str(cm.exception), + "Please re-run `az aks agent-init` to correct the input parameters. Invalid input") @patch('holmes.interactive.SlashCommands') @patch('holmes.utils.colors.ERROR_COLOR') @@ -154,13 +152,9 @@ def test_init_connection_error( mock_error_color.__str__.return_value = "red" mock_slash_commands.EXIT.command = "exit" - with self.assertRaises(SystemExit) as cm: + with self.assertRaises(AzCLIError) as cm: aks_agent_init(cmd=None) - self.assertEqual(cm.exception.code, 1) - mock_console.print.assert_any_call( - "Please check your deployed model and network connectivity.", - style=unittest.mock.ANY, - ) + self.assertEqual(str(cm.exception), "Please check your deployed model and network connectivity. 
Connection failed") if __name__ == '__main__': diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py index 2e68fcd31d1..3a3df9bfae1 100644 --- a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_config_manager.py @@ -6,63 +6,450 @@ import os import tempfile import unittest +from unittest.mock import patch + +import yaml from azext_aks_agent.agent.llm_config_manager import LLMConfigManager +from azure.cli.core.azclierror import AzCLIError class TestLLMConfigManager(unittest.TestCase): + """Test cases for LLMConfigManager class.""" + def setUp(self): - # Create a temporary config file for testing - self.temp_file = tempfile.NamedTemporaryFile(delete=False) - self.config_path = self.temp_file.name - self.manager = LLMConfigManager(config_path=self.config_path) + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.config_file = os.path.join(self.temp_dir, "test_config.yaml") + self.manager = LLMConfigManager() + self.manager.config_path = self.config_file def tearDown(self): - # Remove the temporary file after each test - if os.path.exists(self.config_path): - os.unlink(self.config_path) - - def test_save_and_load(self): - params = {"MODEL_NAME": "test-model", "param1": "value1"} - self.manager.save("openai", params) - loaded = self.manager.load() - self.assertIn("llms", loaded) - self.assertEqual(loaded["llms"][0]["MODEL_NAME"], "test-model") - self.assertEqual(loaded["llms"][0]["provider"], "openai") - - def test_get_list_and_latest(self): - params1 = {"MODEL_NAME": "model1", "param": "v1"} - params2 = {"MODEL_NAME": "model2", "param": "v2"} - self.manager.save("openai", params1) - self.manager.save("openai", params2) - model_list = self.manager.get_list() - self.assertEqual(len(model_list), 2) - latest = self.manager.get_latest() - self.assertEqual(latest["MODEL_NAME"], "model2") - - def test_get_specific(self): - params1 = {"MODEL_NAME": "modelA", "param": "foo"} - params2 = {"MODEL_NAME": "modelB", "param": "bar"} - self.manager.save("openai", params1) - self.manager.save("openai", params2) - specific = self.manager.get_specific("openai", "modelA") - self.assertEqual(specific["param"], "foo") - with self.assertRaises(ValueError): - self.manager.get_specific("openai", "not_exist") - - def test_is_config_complete(self): - config = {"key1": "val1", "key2": "val2"} - schema = { - "key1": {"validator": lambda v: v == "val1"}, - "key2": {"validator": lambda v: v == "val2"} - } - self.assertTrue(self.manager.is_config_complete(config, schema)) - config["key2"] = "wrong" - self.assertFalse(self.manager.is_config_complete(config, schema)) - - def test_load_returns_empty_when_file_missing(self): - # Remove file and test load fallback - os.unlink(self.config_path) - self.assertEqual(self.manager.load(), {}) + """Clean up test fixtures.""" + if os.path.exists(self.config_file): + os.unlink(self.config_file) + os.rmdir(self.temp_dir) + + def test_save_new_config(self): + """Test saving a new configuration when file doesn't exist.""" + config = { + "MODEL_NAME": "gpt-4", + "OPENAI_API_KEY": "test-key", + "OPENAI_API_BASE": "https://api.openai.com/v1" + } + + self.manager.save("openai", config) + + # Verify file was created and contains correct data + self.assertTrue(os.path.exists(self.config_file)) + with open(self.config_file, 'r') as f: + data = yaml.safe_load(f) + 
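+        # Expected on-disk structure written by save() (sketch, based on the assertions below):
+        #   llms:
+        #     - provider: openai
+        #       MODEL_NAME: gpt-4
+        #       OPENAI_API_KEY: test-key
+        #       OPENAI_API_BASE: https://api.openai.com/v1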
+ self.assertIn("llms", data) + self.assertEqual(len(data["llms"]), 1) + expected_config = {"provider": "openai", **config} + self.assertEqual(data["llms"][0], expected_config) + + def test_save_append_to_existing_config(self): + """Test saving a configuration to an existing file.""" + # Create initial config + initial_config = { + "provider": "azure", + "MODEL_NAME": "gpt-3.5", + "AZURE_OPENAI_API_KEY": "initial-key" + } + initial_data = {"llms": [initial_config]} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(initial_data, f) + + # Add new config + new_config = { + "MODEL_NAME": "gpt-4", + "OPENAI_API_KEY": "new-key" + } + + self.manager.save("openai", new_config) + + # Verify both configs exist + with open(self.config_file, 'r') as f: + data = yaml.safe_load(f) + + self.assertEqual(len(data["llms"]), 2) + self.assertEqual(data["llms"][0], initial_config) + expected_new_config = {"provider": "openai", **new_config} + self.assertEqual(data["llms"][1], expected_new_config) + + def test_save_creates_llms_key_if_missing(self): + """Test that save creates 'llms' key if config file exists but is malformed.""" + # Create config file without 'llms' key + malformed_data = {"other_key": "value"} + with open(self.config_file, 'w') as f: + yaml.safe_dump(malformed_data, f) + + config = {"MODEL_NAME": "gpt-4"} + self.manager.save("openai", config) + + with open(self.config_file, 'r') as f: + data = yaml.safe_load(f) + + self.assertIn("llms", data) + self.assertEqual(len(data["llms"]), 1) + expected_config = {"provider": "openai", **config} + self.assertEqual(data["llms"][0], expected_config) + + @patch("builtins.open", side_effect=IOError("Permission denied")) + def test_save_handles_file_write_error(self, mock_file): + """Test that save handles file write errors gracefully.""" + config = {"MODEL_NAME": "gpt-4"} + + with self.assertRaises(IOError): + self.manager.save("openai", config) + + def test_load_existing_file(self): + """Test loading configurations from an existing file.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-4"}, + {"provider": "azure", "MODEL_NAME": "gpt-3.5"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.load() + self.assertEqual(result, data) + + def test_load_handles_invalid_yaml(self): + """Test that load handles invalid YAML content.""" + # Write invalid YAML + with open(self.config_file, 'w') as f: + f.write("invalid: yaml: content: {\n") + + with self.assertRaises(yaml.YAMLError): + self.manager.load() + + def test_get_list_with_configs(self): + """Test get_list returns list of configurations.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-4"}, + {"provider": "azure", "MODEL_NAME": "gpt-3.5"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_list() + self.assertEqual(result, configs) + + def test_get_list_empty_file(self): + """Test get_list returns empty list when no configs exist.""" + result = self.manager.get_list() + self.assertEqual(result, []) + + def test_get_list_missing_llms_key(self): + """Test get_list handles missing 'llms' key gracefully.""" + data = {"other_key": "value"} + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_list() + self.assertEqual(result, []) + + def test_get_latest_with_configs(self): + """Test get_latest returns the most recent configuration.""" + configs = [ + {"provider": "openai", "MODEL_NAME": 
"gpt-3.5"}, + {"provider": "azure", "MODEL_NAME": "gpt-4"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_latest() + self.assertEqual(result, configs[-1]) # Should return last config + + def test_get_latest_no_configs(self): + """Test get_latest returns None when no configurations exist.""" + result = self.manager.get_latest() + self.assertIsNone(result) + + def test_get_specific_found(self): + """Test get_specific returns correct config when found.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-3.5"}, + {"provider": "azure", "MODEL_NAME": "gpt-4"}, + {"provider": "openai", "MODEL_NAME": "gpt-4"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_specific("openai", "gpt-4") + self.assertEqual(result, configs[2]) + + def test_get_specific_not_found(self): + """Test get_specific returns None when config not found.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-3.5"}, + {"provider": "azure", "MODEL_NAME": "gpt-4"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_specific("openai", "gpt-4") + self.assertIsNone(result) + + def test_get_model_config_no_model_param_with_configs(self): + """Test get_model_config returns latest when no model specified.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-3.5"}, + {"provider": "azure", "MODEL_NAME": "gpt-4"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_model_config(None) + self.assertEqual(result, configs[-1]) + + def test_get_model_config_no_model_param_no_configs(self): + """Test get_model_config raises error when no model and no configs.""" + with self.assertRaises(AzCLIError) as cm: + self.manager.get_model_config(None) + + self.assertIn("No LLM configurations found", str(cm.exception)) + self.assertIn("az aks agent-init", str(cm.exception)) + + def test_get_model_config_with_provider_model(self): + """Test get_model_config with provider/model format.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-3.5"}, + {"provider": "azure", "MODEL_NAME": "gpt-4"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_model_config("azure/gpt-4") + self.assertEqual(result, configs[1]) + + def test_get_model_config_with_model_only(self): + """Test get_model_config with model only (defaults to openai).""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-4"}, + {"provider": "azure", "MODEL_NAME": "gpt-4"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + result = self.manager.get_model_config("gpt-4") + self.assertEqual(result, configs[0]) # Should find openai provider + + def test_get_model_config_model_not_found(self): + """Test get_model_config raises error when specified model not found.""" + configs = [ + {"provider": "openai", "MODEL_NAME": "gpt-3.5"} + ] + data = {"llms": configs} + + with open(self.config_file, 'w') as f: + yaml.safe_dump(data, f) + + with self.assertRaises(AzCLIError) as cm: + self.manager.get_model_config("azure/gpt-4") + + self.assertIn("No configuration found for model 'azure/gpt-4'", str(cm.exception)) + + def test_is_config_complete_all_valid(self): + """Test is_config_complete returns True when all validations pass.""" + 
+
+    def test_is_config_complete_all_valid(self):
+        """Test is_config_complete returns True when all validations pass."""
+        config = {
+            "OPENAI_API_KEY": "test-key",
+            "MODEL_NAME": "gpt-4"
+        }
+
+        provider_schema = {
+            "OPENAI_API_KEY": {"validator": lambda x: x and len(x) > 0},
+            "MODEL_NAME": {"validator": lambda x: x and len(x) > 0}
+        }
+
+        result = self.manager.is_config_complete(config, provider_schema)
+        self.assertTrue(result)
+
+    def test_is_config_complete_missing_key(self):
+        """Test is_config_complete returns False when required key is missing."""
+        config = {
+            "OPENAI_API_KEY": "test-key"
+            # Missing MODEL_NAME
+        }
+
+        provider_schema = {
+            "OPENAI_API_KEY": {"validator": lambda x: x and len(x) > 0},
+            "MODEL_NAME": {"validator": lambda x: x and len(x) > 0}
+        }
+
+        result = self.manager.is_config_complete(config, provider_schema)
+        self.assertFalse(result)
+
+    def test_is_config_complete_invalid_value(self):
+        """Test is_config_complete returns False when validation fails."""
+        config = {
+            "OPENAI_API_KEY": "",  # Empty string should fail validation
+            "MODEL_NAME": "gpt-4"
+        }
+
+        provider_schema = {
+            "OPENAI_API_KEY": {"validator": lambda x: x and len(x) > 0},
+            "MODEL_NAME": {"validator": lambda x: x and len(x) > 0}
+        }
+
+        result = self.manager.is_config_complete(config, provider_schema)
+        self.assertFalse(result)
+
+    def test_is_config_complete_no_validator(self):
+        """Test is_config_complete skips keys without validators."""
+        config = {
+            "OPENAI_API_KEY": "test-key",
+            "MODEL_NAME": "gpt-4"
+        }
+
+        provider_schema = {
+            "OPENAI_API_KEY": {},  # No validator
+            "MODEL_NAME": {"validator": lambda x: x and len(x) > 0}
+        }
+
+        result = self.manager.is_config_complete(config, provider_schema)
+        self.assertTrue(result)
+
+    def test_validate_config_valid_structure(self):
+        """Test validate_config with valid YAML structure."""
+        valid_config = {
+            "llms": [
+                {"provider": "openai", "MODEL_NAME": "gpt-4"}
+            ]
+        }
+
+        # Write valid config to file
+        with open(self.config_file, 'w') as f:
+            yaml.safe_dump(valid_config, f)
+
+        # Should not raise any exception
+        self.manager.validate_config()
+
+    def test_validate_config_missing_llms_key(self):
+        """Test validate_config raises error when 'llms' key is missing."""
+        invalid_config = {
+            "other_key": "value"
+        }
+
+        # Write invalid config to file
+        with open(self.config_file, 'w') as f:
+            yaml.safe_dump(invalid_config, f)
+
+        with self.assertRaises(ValueError) as cm:
+            self.manager.validate_config()
+
+        self.assertIn("must contain an 'llms' key", str(cm.exception))
+
+    def test_validate_config_llms_not_list(self):
+        """Test validate_config raises error when 'llms' is not a list."""
+        invalid_config = {
+            "llms": "not a list"
+        }
+
+        # Write invalid config to file
+        with open(self.config_file, 'w') as f:
+            yaml.safe_dump(invalid_config, f)
+
+        with self.assertRaises(ValueError) as cm:
+            self.manager.validate_config()
+
+        self.assertIn("'llms' must be a list", str(cm.exception))
+
+    def test_validate_config_empty_llms_list(self):
+        """Test validate_config raises error when llms list is empty."""
+        invalid_config = {
+            "llms": []
+        }
+
+        # Write config with empty llms list to file
+        with open(self.config_file, 'w') as f:
+            yaml.safe_dump(invalid_config, f)
+
+        with self.assertRaises(ValueError) as cm:
+            self.manager.validate_config()
+
+        self.assertIn("'llms' list cannot be empty", str(cm.exception))
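+
+    # The remaining validate_config cases exercise file-level failures
+    # (missing file, unparsable YAML) and top-level structures that are not
+    # mappings; each surfaces as a ValueError with a descriptive message.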
self.assertIn("Configuration file", str(cm.exception)) + self.assertIn("not found", str(cm.exception)) + + def test_validate_config_invalid_yaml(self): + """Test validate_config raises error for invalid YAML syntax.""" + # Write invalid YAML to file + with open(self.config_file, 'w') as f: + f.write("invalid: yaml: content: {\n") + + with self.assertRaises(ValueError) as cm: + self.manager.validate_config() + + self.assertIn("Invalid YAML syntax", str(cm.exception)) + + def test_validate_config_not_dict(self): + """Test validate_config raises error when config is not a dictionary.""" + # Write a list instead of dict to file + with open(self.config_file, 'w') as f: + yaml.safe_dump(["not", "a", "dict"], f) + + with self.assertRaises(ValueError) as cm: + self.manager.validate_config() + + self.assertIn("must contain a YAML dictionary/mapping", str(cm.exception)) + + def test_validate_config_llm_not_dict(self): + """Test validate_config raises error when LLM config is not a dictionary.""" + invalid_config = { + "llms": ["not a dict"] + } + + # Write config with non-dict LLM config to file + with open(self.config_file, 'w') as f: + yaml.safe_dump(invalid_config, f) + + with self.assertRaises(ValueError) as cm: + self.manager.validate_config() + + self.assertIn("each LLM configuration must be a dictionary/mapping", str(cm.exception)) + + @patch("azext_aks_agent.agent.llm_config_manager.get_config_dir") + def test_validate_config_skips_default_config_path(self, mock_get_config_dir): + """Test validate_config skips validation for default config path.""" + from azext_aks_agent._consts import CONST_AGENT_CONFIG_FILE_NAME + + # Mock the config directory to match our test setup + mock_get_config_dir.return_value = self.temp_dir + + # Set the manager to use the default config path + default_config_path = os.path.join(self.temp_dir, CONST_AGENT_CONFIG_FILE_NAME) + self.manager.config_path = default_config_path + + # Don't create the file - validation should be skipped for default path + # Should not raise any exception + self.manager.validate_config() if __name__ == '__main__': diff --git a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py index 0636af94f47..4e38ce1c266 100644 --- a/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py +++ b/src/aks-agent/azext_aks_agent/tests/latest/test_aks_agent_llm_providers.py @@ -4,7 +4,15 @@ # -------------------------------------------------------------------------------------------- import unittest -from azext_aks_agent.agent.llm_providers import PROVIDER_REGISTRY, AnthropicProvider, GeminiProvider, AzureProvider, OpenAIProvider, OpenAICompatibleProvider + +from azext_aks_agent.agent.llm_providers import ( + PROVIDER_REGISTRY, + AnthropicProvider, + AzureProvider, + GeminiProvider, + OpenAICompatibleProvider, + OpenAIProvider, +) class TestLLMProviders(unittest.TestCase): @@ -18,11 +26,14 @@ def test_provider_registry(self): def test_provider_choices_numbered(self): """Test numbered provider choices are correct and ordered.""" - from azext_aks_agent.agent.llm_providers import _provider_choices_numbered, _available_providers + from azext_aks_agent.agent.llm_providers import ( + _available_providers, + _provider_choices_numbered, + ) choices = _provider_choices_numbered() providers = _available_providers() for idx, name in choices: - self.assertEqual(name, providers[idx-1]) + self.assertEqual(name, providers[idx - 1]().readable_name) if 


 if __name__ == '__main__':
diff --git a/src/aks-agent/setup.py b/src/aks-agent/setup.py
index bb5a288bd02..ddc8ccb1267 100644
--- a/src/aks-agent/setup.py
+++ b/src/aks-agent/setup.py
@@ -9,7 +9,7 @@
 
 from setuptools import find_packages, setup
 
-VERSION = "1.0.0b7"
+VERSION = "1.0.0b8"
 
 CLASSIFIERS = [
     "Development Status :: 4 - Beta",