diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py
index fa6dca6cc..126d52e0c 100644
--- a/src/ragas/llms/base.py
+++ b/src/ragas/llms/base.py
@@ -440,6 +440,62 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(llm={self.llm.__class__.__name__}(...))"
 
 
+def _patch_client_for_provider(client: t.Any, provider: str) -> t.Any:
+    """
+    Patch a client with Instructor for generic providers.
+
+    Maps provider names to Provider enum and instantiates Instructor/AsyncInstructor.
+    Supports anthropic, google, and any other provider Instructor recognizes.
+    """
+    from instructor import Provider
+
+    provider_map = {
+        "anthropic": Provider.ANTHROPIC,
+        "google": Provider.GENAI,
+        "gemini": Provider.GENAI,
+        "azure": Provider.OPENAI,
+        "groq": Provider.GROQ,
+        "mistral": Provider.MISTRAL,
+        "cohere": Provider.COHERE,
+        "xai": Provider.XAI,
+        "bedrock": Provider.BEDROCK,
+        "deepseek": Provider.DEEPSEEK,
+    }
+
+    provider_enum = provider_map.get(provider, Provider.OPENAI)
+
+    if hasattr(client, "acompletion"):
+        return instructor.AsyncInstructor(
+            client=client,
+            create=client.messages.create,
+            provider=provider_enum,
+        )
+    else:
+        return instructor.Instructor(
+            client=client,
+            create=client.messages.create,
+            provider=provider_enum,
+        )
+
+
+def _get_instructor_client(client: t.Any, provider: str) -> t.Any:
+    """
+    Get an instructor-patched client for the specified provider.
+
+    Uses provider-specific methods when available, falls back to generic patcher.
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower == "openai":
+        return instructor.from_openai(client)
+    elif provider_lower == "litellm":
+        return instructor.from_litellm(client)
+    elif provider_lower == "perplexity":
+        return instructor.from_perplexity(client)
+    else:
+        return _patch_client_for_provider(client, provider_lower)
+
+
 def llm_factory(
     model: str,
     provider: str = "openai",
@@ -455,8 +511,8 @@
     Args:
         model: Model name (e.g., "gpt-4o", "gpt-4o-mini", "claude-3-sonnet").
-        provider: LLM provider. Default: "openai".
-            Supported: openai, anthropic, google, litellm.
+        provider: LLM provider (default: "openai").
+            Can be any provider supported by Instructor: openai, anthropic, google, litellm, etc.
         client: Pre-initialized client instance (required).
            For OpenAI, can be OpenAI(...) or AsyncOpenAI(...).
         **kwargs: Additional model arguments (temperature, max_tokens, top_p, etc).
 
@@ -471,13 +527,18 @@
         from openai import OpenAI
         client = OpenAI(api_key="...")
-        llm = llm_factory("gpt-4o", client=client)
+        llm = llm_factory("gpt-4o-mini", client=client)
         response = llm.generate(prompt, ResponseModel)
 
+        # Anthropic
+        from anthropic import Anthropic
+        client = Anthropic(api_key="...")
+        llm = llm_factory("claude-3-sonnet", provider="anthropic", client=client)
+
         # Async
         from openai import AsyncOpenAI
         client = AsyncOpenAI(api_key="...")
-        llm = llm_factory("gpt-4o", client=client)
+        llm = llm_factory("gpt-4o-mini", client=client)
         response = await llm.agenerate(prompt, ResponseModel)
     """
     if client is None:
@@ -496,21 +557,8 @@
     provider_lower = provider.lower()
 
-    instructor_funcs = {
-        "openai": lambda c: instructor.from_openai(c),
-        "anthropic": lambda c: instructor.from_anthropic(c),
-        "google": lambda c: instructor.from_gemini(c),
-        "litellm": lambda c: instructor.from_litellm(c),
-    }
-
-    if provider_lower not in instructor_funcs:
-        raise ValueError(
-            f"Unsupported provider: '{provider}'. "
-            f"Supported: {', '.join(instructor_funcs.keys())}"
-        )
-
     try:
-        patched_client = instructor_funcs[provider_lower](client)
+        patched_client = _get_instructor_client(client, provider_lower)
     except Exception as e:
         raise ValueError(
             f"Failed to initialize {provider} client with instructor. "
@@ -603,7 +651,7 @@ def _map_provider_params(self) -> t.Dict[str, t.Any]:
         Each provider may have different parameter requirements:
         - Google: Wraps parameters in generation_config and renames max_tokens
-        - OpenAI: Maps max_tokens to max_completion_tokens for o-series models
+        - OpenAI/Azure: Maps max_tokens to max_completion_tokens for o-series models
         - Anthropic: No special handling required (pass-through)
         - LiteLLM: No special handling required (routes internally, pass-through)
         """
@@ -611,21 +659,25 @@ def _map_provider_params(self) -> t.Dict[str, t.Any]:
         provider_lower = self.provider.lower()
 
         if provider_lower == "google":
             return self._map_google_params()
-        elif provider_lower == "openai":
+        elif provider_lower in ("openai", "azure"):
             return self._map_openai_params()
         else:
-            # Anthropic, LiteLLM - pass through unchanged
+            # Anthropic, LiteLLM, and other providers - pass through unchanged
             return self.model_args.copy()
 
     def _map_openai_params(self) -> t.Dict[str, t.Any]:
-        """Map parameters for OpenAI reasoning models with special constraints.
+        """Map parameters for OpenAI/Azure reasoning models with special constraints.
 
         Reasoning models (o-series and gpt-5 series) have unique requirements:
         1. max_tokens must be mapped to max_completion_tokens
         2. temperature must be set to 1.0 (only supported value)
         3. top_p parameter must be removed (not supported)
 
-        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+        Legacy OpenAI/Azure models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Note on Azure deployments: Some Azure deployments restrict temperature to 1.0.
+        If your Azure deployment has this constraint, pass temperature=1.0 explicitly:
+            llm_factory("gpt-4o-mini", provider="azure", client=client, temperature=1.0)
 
         For GPT-5 and o-series models with structured output (Pydantic models):
         - Default max_tokens=1024 may not be sufficient
diff --git a/tests/unit/llms/test_instructor_factory.py b/tests/unit/llms/test_instructor_factory.py
index 8367aa607..8781e6a60 100644
--- a/tests/unit/llms/test_instructor_factory.py
+++ b/tests/unit/llms/test_instructor_factory.py
@@ -17,18 +17,22 @@ def __init__(self, is_async=False):
         self.is_async = is_async
         self.chat = Mock()
         self.chat.completions = Mock()
+        self.messages = Mock()
+        self.messages.create = Mock()
         if is_async:
 
             async def async_create(*args, **kwargs):
                 return LLMResponseModel(response="Mock response")
 
             self.chat.completions.create = async_create
+            self.messages.create = async_create
         else:
 
             def sync_create(*args, **kwargs):
                 return LLMResponseModel(response="Mock response")
 
             self.chat.completions.create = sync_create
+            self.messages.create = sync_create
 
 
 class MockInstructor:
@@ -109,11 +113,12 @@ def mock_from_openai(client):
     assert llm.model_args.get("temperature") == 0.7  # type: ignore
 
 
-def test_unsupported_provider():
-    """Test that unsupported providers raise ValueError."""
+def test_unsupported_provider(monkeypatch):
+    """Test that invalid clients are handled gracefully for unknown providers."""
     mock_client = Mock()
+    mock_client.messages = None
 
-    with pytest.raises(ValueError, match="Unsupported provider"):
+    with pytest.raises(ValueError, match="Failed to initialize"):
         llm_factory("test-model", provider="unsupported", client=mock_client)
 
 
@@ -168,28 +173,29 @@ def mock_from_openai(client):
     asyncio.run(llm.agenerate("Test prompt", LLMResponseModel))
 
 
-def test_provider_support():
-    """Test that all expected providers are supported."""
-    supported_providers = {
-        "openai": "from_openai",
-        "anthropic": "from_anthropic",
-        "google": "from_gemini",
-        "litellm": "from_litellm",
-    }
+def test_provider_support(monkeypatch):
+    """Test that major providers are supported."""
 
-    for provider, func_name in supported_providers.items():
-        mock_client = Mock()
-
-        import instructor
+    # OpenAI and LiteLLM use provider-specific methods
+    def mock_from_openai(client):
+        return MockInstructor(client)
 
-        mock_instructor_func = Mock(return_value=MockInstructor(mock_client))
-        setattr(instructor, func_name, mock_instructor_func)
+    def mock_from_litellm(client):
+        return MockInstructor(client)
 
-        try:
-            llm = llm_factory("test-model", provider=provider, client=mock_client)
-            assert llm.model == "test-model"  # type: ignore
-        except Exception as e:
-            pytest.fail(f"Provider {provider} should be supported but got error: {e}")
+    monkeypatch.setattr("instructor.from_openai", mock_from_openai)
+    monkeypatch.setattr("instructor.from_litellm", mock_from_litellm)
+
+    for provider in ["openai", "litellm"]:
+        mock_client = MockClient(is_async=False)
+        llm = llm_factory("test-model", provider=provider, client=mock_client)
+        assert llm.model == "test-model"  # type: ignore
+
+    # Anthropic and Google use generic wrapper
+    for provider in ["anthropic", "google"]:
+        mock_client = MockClient(is_async=False)
+        llm = llm_factory("test-model", provider=provider, client=mock_client)
+        assert llm.model == "test-model"  # type: ignore
 
 
 def test_llm_model_args_storage(mock_sync_client, monkeypatch):
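Usage sketch (not part of the patch): the snippet below exercises the two call paths documented in the updated llm_factory docstring, one through instructor.from_openai and one through the new generic _patch_client_for_provider path. It assumes llm_factory is importable from ragas.llms.base as in this diff; the API keys, model names, and ResponseModel are illustrative placeholders.

```python
# Hedged usage sketch; assumes the import path ragas.llms.base and real API keys.
from anthropic import Anthropic
from openai import OpenAI
from pydantic import BaseModel

from ragas.llms.base import llm_factory


class ResponseModel(BaseModel):
    response: str


# OpenAI keeps its dedicated branch (instructor.from_openai).
openai_llm = llm_factory("gpt-4o-mini", client=OpenAI(api_key="..."))
print(openai_llm.generate("Say hi", ResponseModel))

# Anthropic is now wrapped by the generic patcher, which hands
# client.messages.create to Instructor with Provider.ANTHROPIC.
anthropic_llm = llm_factory(
    "claude-3-sonnet", provider="anthropic", client=Anthropic(api_key="...")
)
print(anthropic_llm.generate("Say hi", ResponseModel))
```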
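Illustration (not part of the patch): a minimal standalone sketch of the three OpenAI/Azure reasoning-model rules described in the _map_openai_params docstring. The helper name and the model-detection check are hypothetical; the actual method body is not shown in this diff.

```python
import typing as t


def sketch_reasoning_param_map(
    model: str, model_args: t.Dict[str, t.Any]
) -> t.Dict[str, t.Any]:
    """Hypothetical helper mirroring the documented rules; not the real method."""
    args = model_args.copy()
    # Assumed detection of o-series / gpt-5 models; the real check may differ.
    if not (model.startswith("o1") or model.startswith("o3") or model.startswith("gpt-5")):
        return args  # legacy OpenAI/Azure models keep max_tokens unchanged
    if "max_tokens" in args:
        args["max_completion_tokens"] = args.pop("max_tokens")  # rule 1
    args["temperature"] = 1.0  # rule 2: only supported value
    args.pop("top_p", None)  # rule 3: not supported
    return args


print(sketch_reasoning_param_map("o3-mini", {"max_tokens": 1024, "temperature": 0.2, "top_p": 0.9}))
# -> {'max_completion_tokens': 1024, 'temperature': 1.0}
```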