diff --git a/.env.example b/.env.example
index 2d29bdbe1..e80bce5a5 100644
--- a/.env.example
+++ b/.env.example
@@ -29,6 +29,12 @@ AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/
 # Get your X.AI API key from: https://console.x.ai/
 XAI_API_KEY=your_xai_api_key_here
 
+# Get your Cerebras API key from: https://inference-docs.cerebras.ai/
+# Cerebras provides ultra-fast inference for ZAI-GLM, OpenAI GPT-OSS, Qwen3, and Llama models
+CEREBRAS_API_KEY=your_cerebras_api_key_here
+# CEREBRAS_ALLOWED_MODELS=zai-glm-4.7
+# CEREBRAS_MODELS_CONFIG_PATH=/path/to/custom_cerebras_models.json
+
 # Get your DIAL API key and configure host URL
 # DIAL provides unified access to multiple AI models through a single API
 DIAL_API_KEY=your_dial_api_key_here
@@ -105,6 +111,16 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # - grok3 (shorthand for grok-3)
 # - grokfast (shorthand for grok-3-fast)
 #
+# Supported Cerebras models:
+# - gpt-oss-120b (128K context, ~3000 tok/s, OpenAI OSS reasoning model)
+# - gpt-oss, oss-120b, openai-oss (shorthands for gpt-oss-120b)
+# - qwen-3-235b-a22b-instruct-2507 (128K context, ~1400 tok/s, frontier coding/reasoning)
+# - qwen3, qwen-3, qwen235b, qwen3-235b (shorthands for qwen-3-235b-a22b-instruct-2507)
+# - zai-glm-4.7 (128K context, ~1000 tok/s, reasoning, tool calling)
+# - cerebras, glm, glm-4.7, zai, zai-glm (shorthands for zai-glm-4.7)
+# - llama3.1-8b (32K context, ~2200 tok/s, fastest small model)
+# - llama8b, llama-8b, llama3.1, llama3-8b (shorthands for llama3.1-8b)
+#
 # Supported DIAL models (when available in your DIAL deployment):
 # - o3-2025-04-16 (200K context, latest O3 release)
 # - o4-mini-2025-04-16 (200K context, latest O4 mini)
@@ -142,6 +158,7 @@ DEFAULT_THINKING_MODE_THINKDEEP=high
 # OPENAI_ALLOWED_MODELS=
 # GOOGLE_ALLOWED_MODELS=
 # XAI_ALLOWED_MODELS=
+# CEREBRAS_ALLOWED_MODELS=zai-glm-4.7  # Only allow zai-glm-4.7
 # DIAL_ALLOWED_MODELS=
 
 # Optional: Custom model configuration file path
diff --git a/conf/cerebras_models.json b/conf/cerebras_models.json
new file mode 100644
index 000000000..45390629a
--- /dev/null
+++ b/conf/cerebras_models.json
@@ -0,0 +1,111 @@
+{
+  "_README": {
+    "description": "Model metadata for Cerebras Inference API.",
+    "documentation": "https://inference-docs.cerebras.ai/models",
+    "usage": "Models listed here are exposed directly through the Cerebras provider. Aliases are case-insensitive.",
+    "field_notes": "Matches providers/shared/model_capabilities.py.",
+    "field_descriptions": {
+      "model_name": "The model identifier (e.g., 'zai-glm-4.7')",
+      "aliases": "Array of short names users can type instead of the full model name",
+      "context_window": "Total number of tokens the model can process (input + output combined)",
+      "max_output_tokens": "Maximum number of tokens the model can generate in a single response",
+      "supports_extended_thinking": "Whether the model supports extended reasoning tokens",
+      "supports_json_mode": "Whether the model can guarantee valid JSON output",
+      "supports_function_calling": "Whether the model supports function/tool calling",
+      "supports_images": "Whether the model can process images/visual input",
+      "supports_temperature": "Whether the model accepts temperature parameter in API calls",
+      "description": "Human-readable description of the model",
+      "intelligence_score": "1-20 human rating used as the primary signal for auto-mode model ordering"
+    }
+  },
+  "models": [
+    {
+      "model_name": "gpt-oss-120b",
+      "friendly_name": "Cerebras (gpt-oss-120b)",
+      "aliases": [
+        "gpt-oss",
+        "oss-120b",
+        "openai-oss"
+      ],
+      "intelligence_score": 17,
+      "description": "OpenAI GPT-OSS 120B — ultra-fast inference (~3000 tok/s), 128K context, internal chain-of-thought reasoning, strong agentic/tool-use",
+      "context_window": 131072,
+      "max_output_tokens": 40000,
+      "supports_extended_thinking": false,
+      "supports_system_prompts": true,
+      "supports_streaming": true,
+      "supports_function_calling": true,
+      "supports_json_mode": true,
+      "supports_images": false,
+      "supports_temperature": true,
+      "temperature_constraint": "range"
+    },
+    {
+      "model_name": "qwen-3-235b-a22b-instruct-2507",
+      "friendly_name": "Cerebras (qwen-3-235b-a22b-instruct-2507)",
+      "aliases": [
+        "qwen3",
+        "qwen-3",
+        "qwen235b",
+        "qwen3-235b"
+      ],
+      "intelligence_score": 16,
+      "description": "Qwen3-235B-A22B Instruct 2507 — fast inference (~1400 tok/s), 128K context, strong coding/reasoning/tool use",
+      "context_window": 131072,
+      "max_output_tokens": 40000,
+      "supports_extended_thinking": false,
+      "supports_system_prompts": true,
+      "supports_streaming": true,
+      "supports_function_calling": true,
+      "supports_json_mode": true,
+      "supports_images": false,
+      "supports_temperature": true,
+      "temperature_constraint": "range"
+    },
+    {
+      "model_name": "zai-glm-4.7",
+      "friendly_name": "Cerebras (zai-glm-4.7)",
+      "aliases": [
+        "cerebras",
+        "glm",
+        "glm-4.7",
+        "zai",
+        "zai-glm"
+      ],
+      "intelligence_score": 14,
+      "description": "Cerebras ZAI-GLM 4.7 — fast inference (~1000 tok/s), 128K context, reasoning model with tool calling",
+      "context_window": 131072,
+      "max_output_tokens": 40000,
+      "supports_extended_thinking": false,
+      "supports_system_prompts": true,
+      "supports_streaming": true,
+      "supports_function_calling": true,
+      "supports_json_mode": true,
+      "supports_images": false,
+      "supports_temperature": true,
+      "temperature_constraint": "range"
+    },
+    {
+      "model_name": "llama3.1-8b",
+      "friendly_name": "Cerebras (llama3.1-8b)",
+      "aliases": [
+        "llama8b",
+        "llama-8b",
+        "llama3.1",
+        "llama3-8b"
+      ],
+      "intelligence_score": 9,
+      "description": "Meta Llama 3.1 8B — fastest small model on Cerebras (~2200 tok/s), 32K context, ideal for real-time and high-throughput tasks",
+      "context_window": 32768,
+      "max_output_tokens": 8192,
+      "supports_extended_thinking": false,
+      "supports_system_prompts": true,
+      "supports_streaming": true,
+      "supports_function_calling": true,
+      "supports_json_mode": true,
+      "supports_images": false,
+      "supports_temperature": true,
+      "temperature_constraint": "range"
+    }
+  ]
+}
"supports_json_mode": true, + "supports_images": false, + "supports_temperature": true, + "temperature_constraint": "range" + } + ] +} diff --git a/docs/configuration.md b/docs/configuration.md index d084f2bd9..9d89175c0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -71,6 +71,7 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended) - `conf/openai_models.json` – OpenAI catalogue (can be overridden with `OPENAI_MODELS_CONFIG_PATH`) - `conf/gemini_models.json` – Gemini catalogue (`GEMINI_MODELS_CONFIG_PATH`) - `conf/xai_models.json` – X.AI / GROK catalogue (`XAI_MODELS_CONFIG_PATH`) + - `conf/cerebras_models.json` – Cerebras Inference catalogue (`CEREBRAS_MODELS_CONFIG_PATH`) - `conf/openrouter_models.json` – OpenRouter catalogue (`OPENROUTER_MODELS_CONFIG_PATH`) - `conf/dial_models.json` – DIAL aggregation catalogue (`DIAL_MODELS_CONFIG_PATH`) - `conf/custom_models.json` – Custom/OpenAI-compatible endpoints (`CUSTOM_MODELS_CONFIG_PATH`) @@ -84,6 +85,7 @@ DEFAULT_MODEL=auto # Claude picks best model for each task (recommended) | OpenAI | `gpt-5.2`, `gpt-5.1-codex`, `gpt-5.1-codex-mini`, `gpt-5`, `gpt-5.2-pro`, `gpt-5-mini`, `gpt-5-nano`, `gpt-5-codex`, `gpt-4.1`, `o3`, `o3-mini`, `o3-pro`, `o4-mini` | `gpt5.2`, `gpt-5.2`, `5.2`, `gpt5.1-codex`, `codex-5.1`, `codex-mini`, `gpt5`, `gpt5pro`, `mini`, `nano`, `codex`, `o3mini`, `o3pro`, `o4mini` | | Gemini | `gemini-2.5-pro`, `gemini-2.5-flash`, `gemini-2.0-flash`, `gemini-2.0-flash-lite` | `pro`, `gemini-pro`, `flash`, `flash-2.0`, `flashlite` | | X.AI | `grok-4`, `grok-4.1-fast` | `grok`, `grok4`, `grok-4.1-fast-reasoning` | + | Cerebras | `zai-glm-4.7`, `gpt-oss-120b`, `qwen-3-235b-a22b-instruct-2507`, `llama3.1-8b` | `cerebras`, `glm`, `zai`, `gpt-oss`, `oss-120b`, `qwen3`, `qwen235b`, `llama8b`, `llama3.1` | | OpenRouter | See `conf/openrouter_models.json` for the continually evolving catalogue | e.g., `opus`, `sonnet`, `flash`, `pro`, `mistral` | | Custom | User-managed entries such as `llama3.2` | Define your own aliases per entry | diff --git a/providers/cerebras.py b/providers/cerebras.py new file mode 100644 index 000000000..2ee8fd874 --- /dev/null +++ b/providers/cerebras.py @@ -0,0 +1,83 @@ +"""Cerebras Inference model provider implementation.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, ClassVar + +if TYPE_CHECKING: + from tools.models import ToolModelCategory + +from .openai_compatible import OpenAICompatibleProvider +from .registries.cerebras import CerebrasModelRegistry +from .registry_provider_mixin import RegistryBackedProviderMixin +from .shared import ModelCapabilities, ProviderType + +logger = logging.getLogger(__name__) + + +class CerebrasModelProvider(RegistryBackedProviderMixin, OpenAICompatibleProvider): + """Integration for Cerebras Inference API. + + Publishes capability metadata for the officially supported deployments and + maps tool-category preferences to the appropriate Cerebras model. + + Model routing by category: + BALANCED → zai-glm-4.7 (default; only model on Cerebras Code plan) + EXTENDED_REASONING → gpt-oss-120b (strongest reasoning, ~3000 tok/s; paid tier) + FAST_RESPONSE → llama3.1-8b (fastest small model, ~2200 tok/s; paid tier) + """ + + FRIENDLY_NAME = "Cerebras" + + REGISTRY_CLASS = CerebrasModelRegistry + MODEL_CAPABILITIES: ClassVar[dict[str, ModelCapabilities]] = {} + + # Category routing — ordered preference lists (first available wins). 
+    # Category routing — ordered preference lists (first available wins).
+    # zai-glm-4.7 is the default: it is the only model on the Cerebras Code
+    # (free) plan and must always be the BALANCED fallback.
+    _REASONING_PREFERENCE = ["gpt-oss-120b", "qwen-3-235b-a22b-instruct-2507", "zai-glm-4.7", "llama3.1-8b"]
+    _BALANCED_PREFERENCE = ["zai-glm-4.7", "qwen-3-235b-a22b-instruct-2507", "gpt-oss-120b", "llama3.1-8b"]
+    _FAST_PREFERENCE = ["llama3.1-8b", "zai-glm-4.7", "qwen-3-235b-a22b-instruct-2507", "gpt-oss-120b"]
+
+    def __init__(self, api_key: str, **kwargs):
+        """Initialize Cerebras provider with API key."""
+        kwargs.setdefault("base_url", "https://api.cerebras.ai/v1")
+        self._ensure_registry()
+        super().__init__(api_key, **kwargs)
+        self._invalidate_capability_cache()
+
+    def get_provider_type(self) -> ProviderType:
+        """Get the provider type."""
+        return ProviderType.CEREBRAS
+
+    def get_preferred_model(self, category: ToolModelCategory, allowed_models: list[str]) -> str | None:
+        """Get Cerebras's preferred model for a given category from allowed models.
+
+        Args:
+            category: The tool category requiring a model
+            allowed_models: Pre-filtered list of models allowed by restrictions
+
+        Returns:
+            Preferred model name or None
+        """
+        if not allowed_models:
+            return None
+
+        from tools.models import ToolModelCategory
+
+        if category == ToolModelCategory.EXTENDED_REASONING:
+            preference = self._REASONING_PREFERENCE
+        elif category == ToolModelCategory.FAST_RESPONSE:
+            preference = self._FAST_PREFERENCE
+        else:  # BALANCED or default
+            preference = self._BALANCED_PREFERENCE
+
+        for model in preference:
+            if model in allowed_models:
+                return model
+        return allowed_models[0]
+
+
+# Load registry data at import time
+CerebrasModelProvider._ensure_registry()
diff --git a/providers/registries/cerebras.py b/providers/registries/cerebras.py
new file mode 100644
index 000000000..d6e151750
--- /dev/null
+++ b/providers/registries/cerebras.py
@@ -0,0 +1,19 @@
+"""Registry loader for Cerebras model capabilities."""
+
+from __future__ import annotations
+
+from ..shared import ProviderType
+from .base import CapabilityModelRegistry
+
+
+class CerebrasModelRegistry(CapabilityModelRegistry):
+    """Capability registry backed by ``conf/cerebras_models.json``."""
+
+    def __init__(self, config_path: str | None = None) -> None:
+        super().__init__(
+            env_var_name="CEREBRAS_MODELS_CONFIG_PATH",
+            default_filename="cerebras_models.json",
+            provider=ProviderType.CEREBRAS,
+            friendly_prefix="Cerebras ({model})",
+            config_path=config_path,
+        )
diff --git a/providers/registry.py b/providers/registry.py
index cd28c4266..a61330dc7 100644
--- a/providers/registry.py
+++ b/providers/registry.py
@@ -40,6 +40,7 @@ class ModelProviderRegistry:
         ProviderType.OPENAI,  # Direct OpenAI access
         ProviderType.AZURE,  # Azure-hosted OpenAI deployments
         ProviderType.XAI,  # Direct X.AI GROK access
+        ProviderType.CEREBRAS,  # Cerebras Inference (ZAI-GLM, GPT-OSS, Qwen3, Llama)
         ProviderType.DIAL,  # DIAL unified API access
         ProviderType.CUSTOM,  # Local/self-hosted models
         ProviderType.OPENROUTER,  # Catch-all for cloud models
@@ -336,6 +337,7 @@ def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]
             ProviderType.OPENAI: "OPENAI_API_KEY",
             ProviderType.AZURE: "AZURE_OPENAI_API_KEY",
             ProviderType.XAI: "XAI_API_KEY",
+            ProviderType.CEREBRAS: "CEREBRAS_API_KEY",
             ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
             ProviderType.CUSTOM: "CUSTOM_API_KEY",  # Can be empty for providers that don't need auth
             ProviderType.DIAL: "DIAL_API_KEY",
diff --git a/providers/shared/provider_type.py b/providers/shared/provider_type.py
index a1b31377f..639e02813 100644
--- a/providers/shared/provider_type.py
+++ b/providers/shared/provider_type.py
@@ -12,6 +12,7 @@ class ProviderType(Enum):
     OPENAI = "openai"
     AZURE = "azure"
     XAI = "xai"
-    OPENROUTER = "openrouter"
-    CUSTOM = "custom"
+    CEREBRAS = "cerebras"
     DIAL = "dial"
+    CUSTOM = "custom"
+    OPENROUTER = "openrouter"
diff --git a/pyproject.toml b/pyproject.toml
index c60506dc1..fae4645a3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ py-modules = ["server", "config"]
     "conf/openai_models.json",
     "conf/gemini_models.json",
     "conf/xai_models.json",
+    "conf/cerebras_models.json",
    "conf/dial_models.json",
]
 
diff --git a/server.py b/server.py
index 74f7ed83f..7b243df27 100644
--- a/server.py
+++ b/server.py
@@ -387,12 +387,20 @@ def configure_providers():
     """
     # Log environment variable status for debugging
     logger.debug("Checking environment variables for API keys...")
-    api_keys_to_check = ["OPENAI_API_KEY", "OPENROUTER_API_KEY", "GEMINI_API_KEY", "XAI_API_KEY", "CUSTOM_API_URL"]
+    api_keys_to_check = [
+        "OPENAI_API_KEY",
+        "OPENROUTER_API_KEY",
+        "GEMINI_API_KEY",
+        "XAI_API_KEY",
+        "CEREBRAS_API_KEY",
+        "CUSTOM_API_URL",
+    ]
     for key in api_keys_to_check:
         value = get_env(key)
         logger.debug(f"  {key}: {'[PRESENT]' if value else '[MISSING]'}")
     from providers import ModelProviderRegistry
     from providers.azure_openai import AzureOpenAIProvider
+    from providers.cerebras import CerebrasModelProvider
     from providers.custom import CustomProvider
     from providers.dial import DIALModelProvider
     from providers.gemini import GeminiModelProvider
@@ -455,6 +463,13 @@ def configure_providers():
         has_native_apis = True
         logger.info("X.AI API key found - GROK models available")
 
+    # Check for Cerebras API key
+    cerebras_key = get_env("CEREBRAS_API_KEY")
+    if cerebras_key and cerebras_key != "your_cerebras_api_key_here":
+        valid_providers.append("Cerebras")
+        has_native_apis = True
+        logger.info("Cerebras API key found - Cerebras Inference models available")
+
     # Check for DIAL API key
     dial_key = get_env("DIAL_API_KEY")
     if dial_key and dial_key != "your_dial_api_key_here":
@@ -513,6 +528,10 @@ def configure_providers():
             ModelProviderRegistry.register_provider(ProviderType.XAI, XAIModelProvider)
             registered_providers.append(ProviderType.XAI.value)
             logger.debug(f"Registered provider: {ProviderType.XAI.value}")
+        if cerebras_key and cerebras_key != "your_cerebras_api_key_here":
+            ModelProviderRegistry.register_provider(ProviderType.CEREBRAS, CerebrasModelProvider)
+            registered_providers.append(ProviderType.CEREBRAS.value)
+            logger.debug(f"Registered provider: {ProviderType.CEREBRAS.value}")
         if dial_key and dial_key != "your_dial_api_key_here":
             ModelProviderRegistry.register_provider(ProviderType.DIAL, DIALModelProvider)
             registered_providers.append(ProviderType.DIAL.value)
@@ -600,7 +619,13 @@ def cleanup_providers():
 
     # Validate restrictions against known models
     provider_instances = {}
-    provider_types_to_validate = [ProviderType.GOOGLE, ProviderType.OPENAI, ProviderType.XAI, ProviderType.DIAL]
+    provider_types_to_validate = [
+        ProviderType.GOOGLE,
+        ProviderType.OPENAI,
+        ProviderType.XAI,
+        ProviderType.CEREBRAS,
+        ProviderType.DIAL,
+    ]
     for provider_type in provider_types_to_validate:
         provider = ModelProviderRegistry.get_provider(provider_type)
         if provider:
diff --git a/simulator_tests/test_chat_simple_validation.py b/simulator_tests/test_chat_simple_validation.py
index a452d71e9..c6709584d 100644
--- a/simulator_tests/test_chat_simple_validation.py
+++ b/simulator_tests/test_chat_simple_validation.py
@@ -13,7 +13,6 @@
 - Conversation context preservation across turns
 """
 
-
 from .conversation_base_test import ConversationBaseTest
 
 
diff --git a/simulator_tests/test_conversation_chain_validation.py b/simulator_tests/test_conversation_chain_validation.py
index 2d70b862b..5ca53338d 100644
--- a/simulator_tests/test_conversation_chain_validation.py
+++ b/simulator_tests/test_conversation_chain_validation.py
@@ -21,7 +21,6 @@
 - Properly traverse parent relationships for history reconstruction
 """
 
-
 from .conversation_base_test import ConversationBaseTest
 
 
diff --git a/simulator_tests/test_cross_tool_comprehensive.py b/simulator_tests/test_cross_tool_comprehensive.py
index 8389953ec..6cdd33901 100644
--- a/simulator_tests/test_cross_tool_comprehensive.py
+++ b/simulator_tests/test_cross_tool_comprehensive.py
@@ -12,7 +12,6 @@
 5. Proper tool chaining with context
 """
 
-
 from .conversation_base_test import ConversationBaseTest
 
 
diff --git a/simulator_tests/test_ollama_custom_url.py b/simulator_tests/test_ollama_custom_url.py
index f23b6ee8d..f40c1e106 100644
--- a/simulator_tests/test_ollama_custom_url.py
+++ b/simulator_tests/test_ollama_custom_url.py
@@ -9,7 +9,6 @@
 - Model alias resolution for local models
 """
 
-
 from .base_test import BaseSimulatorTest
 
 
diff --git a/simulator_tests/test_openrouter_fallback.py b/simulator_tests/test_openrouter_fallback.py
index 91fc058ab..74023437f 100644
--- a/simulator_tests/test_openrouter_fallback.py
+++ b/simulator_tests/test_openrouter_fallback.py
@@ -8,7 +8,6 @@
 - Auto mode correctly selects OpenRouter models
 """
 
-
 from .base_test import BaseSimulatorTest
 
 
diff --git a/simulator_tests/test_openrouter_models.py b/simulator_tests/test_openrouter_models.py
index bd69806a5..5fb3348bb 100644
--- a/simulator_tests/test_openrouter_models.py
+++ b/simulator_tests/test_openrouter_models.py
@@ -9,7 +9,6 @@
 - Error handling when models are not available
 """
 
-
 from .base_test import BaseSimulatorTest
 
 
diff --git a/simulator_tests/test_xai_models.py b/simulator_tests/test_xai_models.py
index 41c57e3a4..e8d32740a 100644
--- a/simulator_tests/test_xai_models.py
+++ b/simulator_tests/test_xai_models.py
@@ -9,7 +9,6 @@
 - API integration and response validation
 """
 
-
 from .base_test import BaseSimulatorTest
 
 
diff --git a/tests/test_auto_mode_model_listing.py b/tests/test_auto_mode_model_listing.py
index 5f1ae1586..701cc56f4 100644
--- a/tests/test_auto_mode_model_listing.py
+++ b/tests/test_auto_mode_model_listing.py
@@ -110,7 +110,7 @@ def test_error_listing_respects_env_restrictions(monkeypatch, reset_registry):
     ):
         monkeypatch.setenv(key, value)
 
-    for var in ("XAI_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
+    for var in ("XAI_API_KEY", "CEREBRAS_API_KEY", "CUSTOM_API_URL", "CUSTOM_API_KEY", "DIAL_API_KEY"):
         monkeypatch.delenv(var, raising=False)
     for azure_var in (
         "AZURE_OPENAI_API_KEY",
@@ -202,6 +202,7 @@ def test_error_listing_without_restrictions_shows_full_catalog(monkeypatch, rese
         "DIAL_ALLOWED_MODELS",
         "CUSTOM_API_URL",
         "CUSTOM_API_KEY",
+        "CEREBRAS_API_KEY",
     ):
         monkeypatch.delenv(var, raising=False)
 
diff --git a/tests/test_cerebras_provider.py b/tests/test_cerebras_provider.py
new file mode 100644
index 000000000..c1c22efa1
--- /dev/null
+++ b/tests/test_cerebras_provider.py
@@ -0,0 +1,462 @@
+"""Tests for Cerebras provider implementation."""
+
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from providers.cerebras import CerebrasModelProvider
+from providers.shared import ProviderType
+
+
+class TestCerebrasProvider:
+    """Test Cerebras provider functionality."""
+
+    def setup_method(self):
+        """Set up clean state before each test."""
+        # Clear restriction service cache before each test
+        import utils.model_restrictions
+
+        utils.model_restrictions._restriction_service = None
+
+    def teardown_method(self):
+        """Clean up after each test to avoid singleton issues."""
+        # Clear restriction service cache after each test
+        import utils.model_restrictions
+
+        utils.model_restrictions._restriction_service = None
+
+    @patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key"})
+    def test_initialization(self):
+        """Test provider initialization."""
+        provider = CerebrasModelProvider("test-key")
+        assert provider.api_key == "test-key"
+        assert provider.get_provider_type() == ProviderType.CEREBRAS
+        assert provider.base_url == "https://api.cerebras.ai/v1"
+
+    def test_initialization_with_custom_url(self):
+        """Test provider initialization with custom base URL."""
+        provider = CerebrasModelProvider("test-key", base_url="https://custom.cerebras.ai/v1")
+        assert provider.api_key == "test-key"
+        assert provider.base_url == "https://custom.cerebras.ai/v1"
+
+    def test_model_validation(self):
+        """Test model name validation."""
+        provider = CerebrasModelProvider("test-key")
+
+        # Test valid models
+        assert provider.validate_model_name("zai-glm-4.7") is True
+        assert provider.validate_model_name("cerebras") is True
+        assert provider.validate_model_name("glm") is True
+        assert provider.validate_model_name("glm-4.7") is True
+        assert provider.validate_model_name("zai") is True
+        assert provider.validate_model_name("zai-glm") is True
+
+        # Test invalid models
+        assert provider.validate_model_name("invalid-model") is False
+        assert provider.validate_model_name("gpt-4") is False
+        assert provider.validate_model_name("gemini-pro") is False
+        assert provider.validate_model_name("grok-4") is False
+
+    def test_resolve_model_name(self):
+        """Test model name resolution."""
+        provider = CerebrasModelProvider("test-key")
+
+        # Test shorthand resolution
+        assert provider._resolve_model_name("cerebras") == "zai-glm-4.7"
+        assert provider._resolve_model_name("glm") == "zai-glm-4.7"
+        assert provider._resolve_model_name("glm-4.7") == "zai-glm-4.7"
+        assert provider._resolve_model_name("zai") == "zai-glm-4.7"
+        assert provider._resolve_model_name("zai-glm") == "zai-glm-4.7"
+
+        # Test full name passthrough
+        assert provider._resolve_model_name("zai-glm-4.7") == "zai-glm-4.7"
+
+    def test_get_capabilities(self):
+        """Test getting model capabilities for zai-glm-4.7."""
+        provider = CerebrasModelProvider("test-key")
+
+        capabilities = provider.get_capabilities("zai-glm-4.7")
+        assert capabilities.model_name == "zai-glm-4.7"
+        assert capabilities.friendly_name == "Cerebras (zai-glm-4.7)"
+        assert capabilities.context_window == 131072
+        assert capabilities.max_output_tokens == 40000
+        assert capabilities.provider == ProviderType.CEREBRAS
+        assert capabilities.supports_extended_thinking is False
+        assert capabilities.supports_system_prompts is True
+        assert capabilities.supports_streaming is True
+        assert capabilities.supports_function_calling is True
+        assert capabilities.supports_json_mode is True
+        assert capabilities.supports_images is False
+        assert capabilities.supports_temperature is True
+
+        # Test temperature range (default range constraint from registry)
+        assert capabilities.temperature_constraint.min_temp == 0.0
+        assert capabilities.temperature_constraint.max_temp == 2.0
+        assert capabilities.temperature_constraint.default_temp == 0.3
+
+    def test_get_capabilities_with_shorthand(self):
+        """Test getting model capabilities with shorthand."""
+        provider = CerebrasModelProvider("test-key")
+
+        capabilities = provider.get_capabilities("cerebras")
+        assert capabilities.model_name == "zai-glm-4.7"  # Should resolve to full name
+        assert capabilities.context_window == 131072
+
+        capabilities_glm = provider.get_capabilities("glm")
+        assert capabilities_glm.model_name == "zai-glm-4.7"
+
+    def test_unsupported_model_capabilities(self):
+        """Test error handling for unsupported models."""
+        provider = CerebrasModelProvider("test-key")
+
+        with pytest.raises(ValueError, match="Unsupported model 'invalid-model' for provider cerebras"):
+            provider.get_capabilities("invalid-model")
+
+    def test_extended_thinking_flags(self):
+        """Cerebras does not support extended thinking (no reasoning-token protocol)."""
+        provider = CerebrasModelProvider("test-key")
+
+        all_aliases = [
+            "zai-glm-4.7",
+            "cerebras",
+            "glm",
+            "glm-4.7",
+            "zai",
+            "zai-glm",
+        ]
+        for alias in all_aliases:
+            assert provider.get_capabilities(alias).supports_extended_thinking is False
+
+    def test_provider_type(self):
+        """Test provider type identification."""
+        provider = CerebrasModelProvider("test-key")
+        assert provider.get_provider_type() == ProviderType.CEREBRAS
+
+    @patch.dict(os.environ, {"CEREBRAS_ALLOWED_MODELS": "zai-glm-4.7"})
+    def test_model_restrictions(self):
+        """Test that CEREBRAS_ALLOWED_MODELS env var is wired into the restriction service."""
+        # Clear cached restriction service
+        import utils.model_restrictions
+        from providers.registry import ModelProviderRegistry
+
+        utils.model_restrictions._restriction_service = None
+        ModelProviderRegistry.reset_for_testing()
+
+        provider = CerebrasModelProvider("test-key")
+
+        # zai-glm-4.7 should be allowed (including alias)
+        assert provider.validate_model_name("zai-glm-4.7") is True
+        assert provider.validate_model_name("cerebras") is True
+
+        # Paid-tier models must be REJECTED when only zai-glm-4.7 is allowed.
+        # This catches the bug where CEREBRAS was missing from
+        # ModelRestrictionService.ENV_VARS and the env var was silently ignored.
+        assert provider.validate_model_name("gpt-oss-120b") is False
+        assert provider.validate_model_name("gpt-oss") is False
+        assert provider.validate_model_name("qwen-3-235b-a22b-instruct-2507") is False
+        assert provider.validate_model_name("qwen3") is False
+        assert provider.validate_model_name("llama3.1-8b") is False
+        assert provider.validate_model_name("llama8b") is False
+
+    @patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key", "CEREBRAS_ALLOWED_MODELS": "zai-glm-4.7"})
+    def test_restrictions_filter_auto_mode_routing(self):
+        """Auto-mode routing must respect CEREBRAS_ALLOWED_MODELS via the registry filter.
+
+        Regression test for the missing ENV_VARS wiring: the provider's
+        get_preferred_model() expects the registry to pre-filter allowed_models,
+        so the centralized restriction service must know about CEREBRAS.
+        """
+        import utils.model_restrictions
+        from providers.registry import ModelProviderRegistry
+
+        utils.model_restrictions._restriction_service = None
+        ModelProviderRegistry.reset_for_testing()
+        ModelProviderRegistry.register_provider(ProviderType.CEREBRAS, CerebrasModelProvider)
+
+        provider = ModelProviderRegistry.get_provider(ProviderType.CEREBRAS)
+        assert provider is not None
+
+        # The registry's allowlist filter must return only zai-glm-4.7.
+        allowed = ModelProviderRegistry._get_allowed_models_for_provider(provider, ProviderType.CEREBRAS)
+        assert allowed == ["zai-glm-4.7"], f"Expected only zai-glm-4.7, got {allowed}"
+
+        # And category routing must therefore always return zai-glm-4.7,
+        # not gpt-oss-120b or llama3.1-8b — even for EXTENDED_REASONING/FAST_RESPONSE
+        # whose preference lists would otherwise pick those paid-tier models first.
+        from tools.models import ToolModelCategory
+
+        for cat in (
+            ToolModelCategory.BALANCED,
+            ToolModelCategory.EXTENDED_REASONING,
+            ToolModelCategory.FAST_RESPONSE,
+        ):
+            assert provider.get_preferred_model(cat, allowed) == "zai-glm-4.7"
+
+    @patch.dict(os.environ, {"CEREBRAS_API_KEY": "test-key", "CEREBRAS_ALLOWED_MODELS": "cerebras"})
+    def test_multiple_model_restrictions(self):
+        """Restrictions specified via alias must accept the canonical name too."""
+        import utils.model_restrictions
+        from providers.registry import ModelProviderRegistry
+
+        utils.model_restrictions._restriction_service = None
+        ModelProviderRegistry.reset_for_testing()
+        # Provider must be registered so the restriction service can resolve
+        # the "cerebras" alias to its canonical name during validation.
+        ModelProviderRegistry.register_provider(ProviderType.CEREBRAS, CerebrasModelProvider)
+        provider = ModelProviderRegistry.get_provider(ProviderType.CEREBRAS)
+
+        # Alias should be allowed (resolves to zai-glm-4.7)
+        assert provider.validate_model_name("cerebras") is True
+        assert provider.validate_model_name("zai-glm-4.7") is True
+        # And paid-tier models must still be rejected
+        assert provider.validate_model_name("gpt-oss-120b") is False
+        assert provider.validate_model_name("llama3.1-8b") is False
+
+    @patch.dict(os.environ, {"CEREBRAS_ALLOWED_MODELS": "zai-glm-4.7,cerebras,glm"})
+    def test_both_shorthand_and_full_name_allowed(self):
+        """Test that aliases and canonical names can be allowed together."""
+        # Clear cached restriction service
+        import utils.model_restrictions
+
+        utils.model_restrictions._restriction_service = None
+
+        provider = CerebrasModelProvider("test-key")
+
+        # Both shorthand and full name should be allowed when explicitly listed
+        assert provider.validate_model_name("zai-glm-4.7") is True
+        assert provider.validate_model_name("cerebras") is True
+        assert provider.validate_model_name("glm") is True
+
+    @patch.dict(os.environ, {"CEREBRAS_ALLOWED_MODELS": ""})
+    def test_empty_restrictions_allows_all(self):
+        """Test that empty restrictions allow all models."""
+        # Clear cached restriction service
+        import utils.model_restrictions
+
+        utils.model_restrictions._restriction_service = None
+
+        provider = CerebrasModelProvider("test-key")
+
+        assert provider.validate_model_name("zai-glm-4.7") is True
+        assert provider.validate_model_name("cerebras") is True
+        assert provider.validate_model_name("glm") is True
+
+    def test_friendly_name(self):
+        """Test friendly name constant."""
+        provider = CerebrasModelProvider("test-key")
+        assert provider.FRIENDLY_NAME == "Cerebras"
+
+        capabilities = provider.get_capabilities("zai-glm-4.7")
+        assert capabilities.friendly_name == "Cerebras (zai-glm-4.7)"
+
+    def test_supported_models_structure(self):
+        """Test that MODEL_CAPABILITIES has all four models with correct structure."""
+        provider = CerebrasModelProvider("test-key")
+
+        from providers.shared import ModelCapabilities
+
+        expected_models = {
+            "gpt-oss-120b": {"context_window": 131072, "max_output_tokens": 40000, "intelligence_score": 17},
+            "qwen-3-235b-a22b-instruct-2507": {
+                "context_window": 131072,
+                "max_output_tokens": 40000,
+                "intelligence_score": 16,
+            },
+            "zai-glm-4.7": {"context_window": 131072, "max_output_tokens": 40000, "intelligence_score": 14},
+            "llama3.1-8b": {"context_window": 32768, "max_output_tokens": 8192, "intelligence_score": 9},
+        }
+        for model_name, expected in expected_models.items():
+            assert model_name in provider.MODEL_CAPABILITIES, f"{model_name} missing from MODEL_CAPABILITIES"
+            config = provider.MODEL_CAPABILITIES[model_name]
+            assert isinstance(config, ModelCapabilities)
+            assert config.context_window == expected["context_window"], f"{model_name} context_window mismatch"
+            assert config.max_output_tokens == expected["max_output_tokens"], f"{model_name} max_output_tokens mismatch"
+            assert config.supports_extended_thinking is False, f"{model_name} should not claim extended thinking"
+
+        # Spot-check aliases
+        assert "cerebras" in provider.MODEL_CAPABILITIES["zai-glm-4.7"].aliases
+        assert "gpt-oss" in provider.MODEL_CAPABILITIES["gpt-oss-120b"].aliases
+        assert "qwen3" in provider.MODEL_CAPABILITIES["qwen-3-235b-a22b-instruct-2507"].aliases
+        assert "llama8b" in provider.MODEL_CAPABILITIES["llama3.1-8b"].aliases
+
+    def test_new_model_capabilities_gpt_oss(self):
+        """Test gpt-oss-120b capabilities and alias resolution."""
+        provider = CerebrasModelProvider("test-key")
+
+        for alias in ("gpt-oss-120b", "gpt-oss", "oss-120b", "openai-oss"):
+            caps = provider.get_capabilities(alias)
+            assert caps.model_name == "gpt-oss-120b"
+            assert caps.context_window == 131072
+            assert caps.max_output_tokens == 40000
+            assert caps.supports_function_calling is True
+            assert caps.supports_extended_thinking is False
+
+    def test_new_model_capabilities_qwen3(self):
+        """Test qwen-3-235b capabilities and alias resolution."""
+        provider = CerebrasModelProvider("test-key")
+
+        for alias in ("qwen-3-235b-a22b-instruct-2507", "qwen3", "qwen-3", "qwen235b", "qwen3-235b"):
+            caps = provider.get_capabilities(alias)
+            assert caps.model_name == "qwen-3-235b-a22b-instruct-2507"
+            assert caps.context_window == 131072
+            assert caps.max_output_tokens == 40000
+            assert caps.supports_function_calling is True
+            assert caps.supports_extended_thinking is False
+
+    def test_new_model_capabilities_llama(self):
+        """Test llama3.1-8b capabilities and alias resolution."""
+        provider = CerebrasModelProvider("test-key")
+
+        for alias in ("llama3.1-8b", "llama8b", "llama-8b", "llama3.1", "llama3-8b"):
+            caps = provider.get_capabilities(alias)
+            assert caps.model_name == "llama3.1-8b"
+            assert caps.context_window == 32768
+            assert caps.max_output_tokens == 8192
+            assert caps.supports_function_calling is True
+            assert caps.supports_extended_thinking is False
+
+    def test_get_preferred_model_routing(self):
+        """Test category-based model routing across all four models."""
+        from tools.models import ToolModelCategory
+
+        provider = CerebrasModelProvider("test-key")
+        all_models = ["gpt-oss-120b", "qwen-3-235b-a22b-instruct-2507", "zai-glm-4.7", "llama3.1-8b"]
+
+        # BALANCED → zai-glm-4.7 (default; only model on Cerebras Code plan)
+        assert provider.get_preferred_model(ToolModelCategory.BALANCED, all_models) == "zai-glm-4.7"
+
+        # EXTENDED_REASONING → gpt-oss-120b (strongest reasoner; paid tier)
+        assert provider.get_preferred_model(ToolModelCategory.EXTENDED_REASONING, all_models) == "gpt-oss-120b"
+
+        # FAST_RESPONSE → llama3.1-8b (fastest small model; paid tier)
+        assert provider.get_preferred_model(ToolModelCategory.FAST_RESPONSE, all_models) == "llama3.1-8b"
+
+    def test_get_preferred_model_fallback(self):
+        """Test category routing falls back gracefully when top choice unavailable."""
+        from tools.models import ToolModelCategory
+
+        provider = CerebrasModelProvider("test-key")
+
+        # Code plan (zai-glm-4.7 only) → always returns zai-glm-4.7 for any category
+        for cat in [ToolModelCategory.BALANCED, ToolModelCategory.EXTENDED_REASONING, ToolModelCategory.FAST_RESPONSE]:
+            assert provider.get_preferred_model(cat, ["zai-glm-4.7"]) == "zai-glm-4.7"
+
+        # Without gpt-oss-120b, EXTENDED_REASONING falls back to qwen3
+        assert (
+            provider.get_preferred_model(
+                ToolModelCategory.EXTENDED_REASONING,
+                ["qwen-3-235b-a22b-instruct-2507", "zai-glm-4.7"],
+            )
+            == "qwen-3-235b-a22b-instruct-2507"
+        )
+
+        # Without llama3.1-8b, FAST_RESPONSE falls back to zai-glm-4.7
+        assert (
+            provider.get_preferred_model(
+                ToolModelCategory.FAST_RESPONSE,
+                ["zai-glm-4.7", "gpt-oss-120b"],
+            )
+            == "zai-glm-4.7"
+        )
+
+        # Empty list → None
+        assert provider.get_preferred_model(ToolModelCategory.BALANCED, []) is None
+
+    @patch("providers.openai_compatible.OpenAI")
+    def test_generate_content_resolves_alias_before_api_call(self, mock_openai_class):
+        """Test that generate_content resolves aliases before making API calls.
+
+        This is the CRITICAL test that ensures aliases like 'cerebras' get resolved
+        to 'zai-glm-4.7' before being sent to the Cerebras API.
+        """
+        # Set up mock OpenAI client
+        mock_client = MagicMock()
+        mock_openai_class.return_value = mock_client
+
+        # Mock the completion response
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].message.content = "Test response"
+        mock_response.choices[0].finish_reason = "stop"
+        mock_response.model = "zai-glm-4.7"  # API returns the resolved model name
+        mock_response.id = "test-id"
+        mock_response.created = 1234567890
+        mock_response.usage = MagicMock()
+        mock_response.usage.prompt_tokens = 10
+        mock_response.usage.completion_tokens = 5
+        mock_response.usage.total_tokens = 15
+
+        mock_client.chat.completions.create.return_value = mock_response
+
+        provider = CerebrasModelProvider("test-key")
+
+        # Call generate_content with alias 'cerebras'
+        result = provider.generate_content(
+            prompt="Test prompt",
+            model_name="cerebras",  # alias; must be resolved to "zai-glm-4.7"
+            temperature=0.7,
+        )
+
+        # Verify the API was called with the RESOLVED model name
+        mock_client.chat.completions.create.assert_called_once()
+        call_kwargs = mock_client.chat.completions.create.call_args[1]
+
+        # CRITICAL ASSERTION: The API should receive "zai-glm-4.7", not "cerebras"
+        assert (
+            call_kwargs["model"] == "zai-glm-4.7"
+        ), f"Expected 'zai-glm-4.7' but API received '{call_kwargs['model']}'"
+
+        # Verify other parameters
+        assert call_kwargs["temperature"] == 0.7
+        assert len(call_kwargs["messages"]) == 1
+        assert call_kwargs["messages"][0]["role"] == "user"
+        assert call_kwargs["messages"][0]["content"] == "Test prompt"
+
+        # Verify response
+        assert result.content == "Test response"
+        assert result.model_name == "zai-glm-4.7"  # Should be the resolved name
+
+    @patch("providers.openai_compatible.OpenAI")
+    def test_generate_content_other_aliases(self, mock_openai_class):
+        """Test other alias resolutions in generate_content."""
+        # Set up mock
+        mock_client = MagicMock()
+        mock_openai_class.return_value = mock_client
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].message.content = "Test response"
+        mock_response.choices[0].finish_reason = "stop"
+        mock_response.usage = MagicMock()
+        mock_response.usage.prompt_tokens = 10
+        mock_response.usage.completion_tokens = 5
+        mock_response.usage.total_tokens = 15
+        mock_client.chat.completions.create.return_value = mock_response
+
+        provider = CerebrasModelProvider("test-key")
+
+        # Test glm -> zai-glm-4.7
+        mock_response.model = "zai-glm-4.7"
+        provider.generate_content(prompt="Test", model_name="glm", temperature=0.7)
+        call_kwargs = mock_client.chat.completions.create.call_args[1]
+        assert call_kwargs["model"] == "zai-glm-4.7"
+
+        # Test glm-4.7 -> zai-glm-4.7
+        provider.generate_content(prompt="Test", model_name="glm-4.7", temperature=0.7)
+        call_kwargs = mock_client.chat.completions.create.call_args[1]
+        assert call_kwargs["model"] == "zai-glm-4.7"
+
+        # Test zai -> zai-glm-4.7
+        provider.generate_content(prompt="Test", model_name="zai", temperature=0.7)
+        call_kwargs = mock_client.chat.completions.create.call_args[1]
+        assert call_kwargs["model"] == "zai-glm-4.7"
+
+        # Test zai-glm -> zai-glm-4.7
+        provider.generate_content(prompt="Test", model_name="zai-glm", temperature=0.7)
+        call_kwargs = mock_client.chat.completions.create.call_args[1]
+        assert call_kwargs["model"] == "zai-glm-4.7"
+
+        # Test zai-glm-4.7 -> zai-glm-4.7 (passthrough)
+        provider.generate_content(prompt="Test", model_name="zai-glm-4.7", temperature=0.7)
+        call_kwargs = mock_client.chat.completions.create.call_args[1]
+        assert call_kwargs["model"] == "zai-glm-4.7"
diff --git a/tests/test_directory_expansion_tracking.py b/tests/test_directory_expansion_tracking.py
index f4e56a019..79ac5adf9 100644
--- a/tests/test_directory_expansion_tracking.py
+++ b/tests/test_directory_expansion_tracking.py
@@ -37,8 +37,7 @@ def temp_directory_with_files(self, project_path):
         files = []
         for i in range(5):
             swift_file = temp_path / f"File{i}.swift"
-            swift_file.write_text(
-                f"""
+            swift_file.write_text(f"""
 import Foundation
 
 class TestClass{i} {{
@@ -46,18 +45,15 @@ class TestClass{i} {{
         return "test{i}"
     }}
 }}
-"""
-            )
+""")
             files.append(str(swift_file))
 
         # Create a Python file as well
         python_file = temp_path / "helper.py"
-        python_file.write_text(
-            """
+        python_file.write_text("""
 def helper_function():
     return "helper"
-"""
-        )
+""")
         files.append(str(python_file))
 
         try:
diff --git a/tests/test_docker_implementation.py b/tests/test_docker_implementation.py
index d93ca9ff4..ad99976e3 100644
--- a/tests/test_docker_implementation.py
+++ b/tests/test_docker_implementation.py
@@ -310,13 +310,11 @@ def temp_project_dir():
     # Create base files
     (temp_path / "server.py").write_text("# Mock server.py")
 
-    (temp_path / "Dockerfile").write_text(
-        """
+    (temp_path / "Dockerfile").write_text("""
 FROM python:3.11-slim
 COPY server.py /app/
 CMD ["python", "/app/server.py"]
-"""
-    )
+""")
 
     yield temp_path
 
diff --git a/tests/test_prompt_regression.py b/tests/test_prompt_regression.py
index bf40164c7..a2bdf45c7 100644
--- a/tests/test_prompt_regression.py
+++ b/tests/test_prompt_regression.py
@@ -86,16 +86,14 @@ async def test_chat_with_files(self):
 
         # Create a temporary Python file for testing
         with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
-            f.write(
-                """
+            f.write("""
 def hello_world():
     \"\"\"A simple hello world function.\"\"\"
     return "Hello, World!"
 
 if __name__ == "__main__":
     print(hello_world())
-"""
-            )
+""")
             temp_file = f.name
 
         try:
@@ -155,8 +153,7 @@ async def test_codereview_normal_review(self):
 
         # Create a temporary Python file for testing
         with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
-            f.write(
-                """
+            f.write("""
 def process_user_input(user_input):
     # Potentially unsafe code for demonstration
     query = f"SELECT * FROM users WHERE name = '{user_input}'"
@@ -166,8 +163,7 @@ def main():
     user_name = input("Enter name: ")
     result = process_user_input(user_name)
     print(result)
-"""
-            )
+""")
             temp_file = f.name
 
         try:
@@ -241,8 +237,7 @@ async def test_analyze_normal_question(self):
 
         # Create a temporary Python file demonstrating MVC pattern
         with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
-            f.write(
-                """
+            f.write("""
 # Model
 class User:
     def __init__(self, name, email):
@@ -262,8 +257,7 @@ def __init__(self, model, view):
     def get_user_display(self):
         return self.view.display_user(self.model)
 
-"""
-            )
+""")
             temp_file = f.name
 
         try:
diff --git a/tools/listmodels.py b/tools/listmodels.py
index 120afc189..09615fbe7 100644
--- a/tools/listmodels.py
+++ b/tools/listmodels.py
@@ -102,6 +102,7 @@ async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
             ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
             ProviderType.AZURE: {"name": "Azure OpenAI", "env_key": "AZURE_OPENAI_API_KEY"},
             ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
+            ProviderType.CEREBRAS: {"name": "Cerebras", "env_key": "CEREBRAS_API_KEY"},
             ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
         }
 
diff --git a/utils/model_restrictions.py b/utils/model_restrictions.py
index c06839156..8ab7eb660 100644
--- a/utils/model_restrictions.py
+++ b/utils/model_restrictions.py
@@ -10,6 +10,7 @@
 - OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
 - GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
 - XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
+- CEREBRAS_ALLOWED_MODELS: Comma-separated list of allowed Cerebras models
 - OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
 - DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models
 
@@ -52,6 +53,7 @@ class ModelRestrictionService:
         ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
         ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
         ProviderType.XAI: "XAI_ALLOWED_MODELS",
+        ProviderType.CEREBRAS: "CEREBRAS_ALLOWED_MODELS",
         ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
         ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
     }
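
Usage sketch (illustrative, based on the provider code and tests above): how alias
resolution and category routing behave once the provider is constructed. It assumes
the modules added in this diff are importable, that no CEREBRAS_ALLOWED_MODELS
restriction is set, and the key value is a placeholder.

    from providers.cerebras import CerebrasModelProvider
    from tools.models import ToolModelCategory

    provider = CerebrasModelProvider("your_cerebras_api_key_here")

    # Aliases from conf/cerebras_models.json resolve before any API call.
    assert provider._resolve_model_name("glm") == "zai-glm-4.7"
    assert provider.get_capabilities("gpt-oss").context_window == 131072

    # Category routing walks the ordered preference lists (first available wins).
    all_models = ["gpt-oss-120b", "qwen-3-235b-a22b-instruct-2507", "zai-glm-4.7", "llama3.1-8b"]
    assert provider.get_preferred_model(ToolModelCategory.EXTENDED_REASONING, all_models) == "gpt-oss-120b"
    assert provider.get_preferred_model(ToolModelCategory.FAST_RESPONSE, all_models) == "llama3.1-8b"
    assert provider.get_preferred_model(ToolModelCategory.BALANCED, all_models) == "zai-glm-4.7"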
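Restriction wiring sketch (illustrative, mirroring test_restrictions_filter_auto_mode_routing
above): CEREBRAS_ALLOWED_MODELS gates both name validation and auto-mode routing once the
centralized restriction service knows about the provider. It assumes a fresh process where
the environment variables are set before the registry is used; the key is a placeholder.

    import os

    os.environ["CEREBRAS_API_KEY"] = "test-key"  # placeholder key
    os.environ["CEREBRAS_ALLOWED_MODELS"] = "zai-glm-4.7"

    import utils.model_restrictions
    from providers.cerebras import CerebrasModelProvider
    from providers.registry import ModelProviderRegistry
    from providers.shared import ProviderType

    utils.model_restrictions._restriction_service = None  # drop any cached service
    ModelProviderRegistry.register_provider(ProviderType.CEREBRAS, CerebrasModelProvider)
    provider = ModelProviderRegistry.get_provider(ProviderType.CEREBRAS)

    # Only the allowed model (and aliases resolving to it) validate; the rest are rejected.
    assert provider.validate_model_name("cerebras") is True
    assert provider.validate_model_name("gpt-oss-120b") is False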