98 changes: 75 additions & 23 deletions src/ragas/llms/base.py
@@ -440,6 +440,62 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(llm={self.llm.__class__.__name__}(...))"


def _patch_client_for_provider(client: t.Any, provider: str) -> t.Any:
"""
Patch a client with Instructor for generic providers.

Maps provider names to Provider enum and instantiates Instructor/AsyncInstructor.
Supports anthropic, google, and any other provider Instructor recognizes.
"""
from instructor import Provider

provider_map = {
"anthropic": Provider.ANTHROPIC,
"google": Provider.GENAI,
"gemini": Provider.GENAI,
"azure": Provider.OPENAI,
"groq": Provider.GROQ,
"mistral": Provider.MISTRAL,
"cohere": Provider.COHERE,
"xai": Provider.XAI,
"bedrock": Provider.BEDROCK,
"deepseek": Provider.DEEPSEEK,
}

provider_enum = provider_map.get(provider, Provider.OPENAI)

if hasattr(client, "acompletion"):
return instructor.AsyncInstructor(
client=client,
create=client.messages.create,
provider=provider_enum,
)
else:
return instructor.Instructor(
client=client,
create=client.messages.create,
provider=provider_enum,
)

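The patcher above relies on two assumptions worth calling out: the client exposes an Anthropic-style `messages.create` method, and async clients are detected via an `acompletion` attribute. A minimal sketch of the sync path, reusing the Anthropic client from the docstring examples further down (illustrative only, not part of the diff):

# --- illustrative sketch, not part of the diff ---
from anthropic import Anthropic

client = Anthropic(api_key="...")  # placeholder key
# Anthropic's sync client has no `acompletion` attribute, so the sync branch
# runs and the Instructor wrapper is built around `client.messages.create`
# with Provider.ANTHROPIC.
patched = _patch_client_for_provider(client, "anthropic")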

def _get_instructor_client(client: t.Any, provider: str) -> t.Any:
"""
Get an instructor-patched client for the specified provider.

Uses provider-specific methods when available, falls back to generic patcher.
"""
provider_lower = provider.lower()

if provider_lower == "openai":
return instructor.from_openai(client)
elif provider_lower == "litellm":
return instructor.from_litellm(client)
elif provider_lower == "perplexity":
return instructor.from_perplexity(client)
else:
return _patch_client_for_provider(client, provider_lower)

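For readers tracing the dispatch: providers with a dedicated Instructor constructor (`from_openai`, `from_litellm`, `from_perplexity`) take that path, and everything else falls back to the generic patcher above. A hedged usage sketch (illustrative only, not part of the diff; assumes the `openai` and `anthropic` packages are installed):

# --- illustrative sketch, not part of the diff ---
from openai import OpenAI
from anthropic import Anthropic

# Provider-specific path: wrapped via instructor.from_openai(...)
openai_patched = _get_instructor_client(OpenAI(api_key="..."), "openai")

# Generic path: no dedicated branch here, so _patch_client_for_provider handles it
anthropic_patched = _get_instructor_client(Anthropic(api_key="..."), "anthropic")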

def llm_factory(
model: str,
provider: str = "openai",
@@ -455,8 +511,8 @@ def llm_factory(

Args:
model: Model name (e.g., "gpt-4o", "gpt-4o-mini", "claude-3-sonnet").
provider: LLM provider. Default: "openai".
Supported: openai, anthropic, google, litellm.
provider: LLM provider (default: "openai").
Can be any provider supported by Instructor: openai, anthropic, google, litellm, etc.
client: Pre-initialized client instance (required). For OpenAI, can be
OpenAI(...) or AsyncOpenAI(...).
**kwargs: Additional model arguments (temperature, max_tokens, top_p, etc).
@@ -471,13 +527,18 @@ def llm_factory(
from openai import OpenAI

client = OpenAI(api_key="...")
llm = llm_factory("gpt-4o", client=client)
llm = llm_factory("gpt-4o-mini", client=client)
response = llm.generate(prompt, ResponseModel)

# Anthropic
from anthropic import Anthropic
client = Anthropic(api_key="...")
llm = llm_factory("claude-3-sonnet", provider="anthropic", client=client)

# Async
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key="...")
llm = llm_factory("gpt-4o", client=client)
llm = llm_factory("gpt-4o-mini", client=client)
response = await llm.agenerate(prompt, ResponseModel)
"""
if client is None:
@@ -496,21 +557,8 @@ def llm_factory(

provider_lower = provider.lower()

instructor_funcs = {
"openai": lambda c: instructor.from_openai(c),
"anthropic": lambda c: instructor.from_anthropic(c),
"google": lambda c: instructor.from_gemini(c),
"litellm": lambda c: instructor.from_litellm(c),
}

if provider_lower not in instructor_funcs:
raise ValueError(
f"Unsupported provider: '{provider}'. "
f"Supported: {', '.join(instructor_funcs.keys())}"
)

try:
patched_client = instructor_funcs[provider_lower](client)
patched_client = _get_instructor_client(client, provider_lower)
except Exception as e:
raise ValueError(
f"Failed to initialize {provider} client with instructor. "
@@ -603,29 +651,33 @@ def _map_provider_params(self) -> t.Dict[str, t.Any]:

Each provider may have different parameter requirements:
- Google: Wraps parameters in generation_config and renames max_tokens
- OpenAI: Maps max_tokens to max_completion_tokens for o-series models
- OpenAI/Azure: Maps max_tokens to max_completion_tokens for o-series models
- Anthropic: No special handling required (pass-through)
- LiteLLM: No special handling required (routes internally, pass-through)
"""
provider_lower = self.provider.lower()

if provider_lower == "google":
return self._map_google_params()
elif provider_lower == "openai":
elif provider_lower in ("openai", "azure"):
return self._map_openai_params()
else:
# Anthropic, LiteLLM - pass through unchanged
# Anthropic, LiteLLM, and other providers - pass through unchanged
return self.model_args.copy()

def _map_openai_params(self) -> t.Dict[str, t.Any]:
"""Map parameters for OpenAI reasoning models with special constraints.
"""Map parameters for OpenAI/Azure reasoning models with special constraints.

Reasoning models (o-series and gpt-5 series) have unique requirements:
1. max_tokens must be mapped to max_completion_tokens
2. temperature must be set to 1.0 (only supported value)
3. top_p parameter must be removed (not supported)

Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
Legacy OpenAI/Azure models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.

Note on Azure deployments: Some Azure deployments restrict temperature to 1.0.
If your Azure deployment has this constraint, pass temperature=1.0 explicitly:
llm_factory("gpt-4o-mini", provider="azure", client=client, temperature=1.0)

For GPT-5 and o-series models with structured output (Pydantic models):
- Default max_tokens=1024 may not be sufficient
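Based on the constraints documented above, a reasoning-model request is expected to be rewritten roughly as follows. This is an inference from the docstring, not code from the diff; the model name and values are placeholders:

# --- illustrative sketch, not part of the diff ---
# Input model_args for a reasoning model such as "o3-mini":
model_args = {"max_tokens": 1024, "temperature": 0.7, "top_p": 0.9}

# Expected result of _map_openai_params():
#   max_tokens  -> max_completion_tokens
#   temperature -> forced to 1.0 (the only supported value)
#   top_p       -> dropped (not supported)
expected = {"max_completion_tokens": 1024, "temperature": 1.0}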
50 changes: 28 additions & 22 deletions tests/unit/llms/test_instructor_factory.py
@@ -17,18 +17,22 @@ def __init__(self, is_async=False):
self.is_async = is_async
self.chat = Mock()
self.chat.completions = Mock()
self.messages = Mock()
self.messages.create = Mock()
if is_async:

async def async_create(*args, **kwargs):
return LLMResponseModel(response="Mock response")

self.chat.completions.create = async_create
self.messages.create = async_create
else:

def sync_create(*args, **kwargs):
return LLMResponseModel(response="Mock response")

self.chat.completions.create = sync_create
self.messages.create = sync_create


class MockInstructor:
@@ -109,11 +113,12 @@ def mock_from_openai(client):
assert llm.model_args.get("temperature") == 0.7 # type: ignore


def test_unsupported_provider():
"""Test that unsupported providers raise ValueError."""
def test_unsupported_provider(monkeypatch):
"""Test that invalid clients are handled gracefully for unknown providers."""
mock_client = Mock()
mock_client.messages = None

with pytest.raises(ValueError, match="Unsupported provider"):
with pytest.raises(ValueError, match="Failed to initialize"):
llm_factory("test-model", provider="unsupported", client=mock_client)


@@ -168,28 +173,29 @@ def mock_from_openai(client):
asyncio.run(llm.agenerate("Test prompt", LLMResponseModel))


def test_provider_support():
"""Test that all expected providers are supported."""
supported_providers = {
"openai": "from_openai",
"anthropic": "from_anthropic",
"google": "from_gemini",
"litellm": "from_litellm",
}
def test_provider_support(monkeypatch):
"""Test that major providers are supported."""

for provider, func_name in supported_providers.items():
mock_client = Mock()

import instructor
# OpenAI and LiteLLM use provider-specific methods
def mock_from_openai(client):
return MockInstructor(client)

mock_instructor_func = Mock(return_value=MockInstructor(mock_client))
setattr(instructor, func_name, mock_instructor_func)
def mock_from_litellm(client):
return MockInstructor(client)

try:
llm = llm_factory("test-model", provider=provider, client=mock_client)
assert llm.model == "test-model" # type: ignore
except Exception as e:
pytest.fail(f"Provider {provider} should be supported but got error: {e}")
monkeypatch.setattr("instructor.from_openai", mock_from_openai)
monkeypatch.setattr("instructor.from_litellm", mock_from_litellm)

for provider in ["openai", "litellm"]:
mock_client = MockClient(is_async=False)
llm = llm_factory("test-model", provider=provider, client=mock_client)
assert llm.model == "test-model" # type: ignore

# Anthropic and Google use generic wrapper
for provider in ["anthropic", "google"]:
mock_client = MockClient(is_async=False)
llm = llm_factory("test-model", provider=provider, client=mock_client)
assert llm.model == "test-model" # type: ignore


def test_llm_model_args_storage(mock_sync_client, monkeypatch):