
Commit 1ffc6d5

feat: migrate to instructor.from_provider for universal provider support (#2424)
A fix to support the latest instructor release, which removed the `from_anthropic` and `from_gemini` methods in favor of the more standard `from_provider`. Ref: [PR 1898](567-labs/instructor#1898). Also adds support for #2422.
1 parent: 41cc83b
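For context, a minimal sketch of the upstream entry point this commit tracks; the exact version cutoff and model string here are assumptions, not part of the diff:

```python
import instructor

# from_provider("provider/model") replaces the removed provider-specific
# entry points such as instructor.from_anthropic(...) and instructor.from_gemini(...)
client = instructor.from_provider("anthropic/claude-3-5-sonnet-20240620")
```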

File tree: 2 files changed (+103, -45 lines)


src/ragas/llms/base.py (+75, -23)
```diff
@@ -440,6 +440,62 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(llm={self.llm.__class__.__name__}(...))"
 
 
+def _patch_client_for_provider(client: t.Any, provider: str) -> t.Any:
+    """
+    Patch a client with Instructor for generic providers.
+
+    Maps provider names to Provider enum and instantiates Instructor/AsyncInstructor.
+    Supports anthropic, google, and any other provider Instructor recognizes.
+    """
+    from instructor import Provider
+
+    provider_map = {
+        "anthropic": Provider.ANTHROPIC,
+        "google": Provider.GENAI,
+        "gemini": Provider.GENAI,
+        "azure": Provider.OPENAI,
+        "groq": Provider.GROQ,
+        "mistral": Provider.MISTRAL,
+        "cohere": Provider.COHERE,
+        "xai": Provider.XAI,
+        "bedrock": Provider.BEDROCK,
+        "deepseek": Provider.DEEPSEEK,
+    }
+
+    provider_enum = provider_map.get(provider, Provider.OPENAI)
+
+    if hasattr(client, "acompletion"):
+        return instructor.AsyncInstructor(
+            client=client,
+            create=client.messages.create,
+            provider=provider_enum,
+        )
+    else:
+        return instructor.Instructor(
+            client=client,
+            create=client.messages.create,
+            provider=provider_enum,
+        )
+
+
+def _get_instructor_client(client: t.Any, provider: str) -> t.Any:
+    """
+    Get an instructor-patched client for the specified provider.
+
+    Uses provider-specific methods when available, falls back to generic patcher.
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower == "openai":
+        return instructor.from_openai(client)
+    elif provider_lower == "litellm":
+        return instructor.from_litellm(client)
+    elif provider_lower == "perplexity":
+        return instructor.from_perplexity(client)
+    else:
+        return _patch_client_for_provider(client, provider_lower)
+
+
 def llm_factory(
     model: str,
     provider: str = "openai",
```
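A short usage sketch of the dispatch above; the Anthropic client is illustrative, and any client exposing `.messages.create` follows the same fallback path:

```python
from anthropic import Anthropic

client = Anthropic(api_key="...")

# "anthropic" matches none of the from_* branches, so _patch_client_for_provider
# wraps the client in instructor.Instructor with Provider.ANTHROPIC, binding
# client.messages.create as the create callable.
patched = _get_instructor_client(client, "anthropic")

# A provider missing from provider_map falls back to Provider.OPENAI
# via provider_map.get(provider, Provider.OPENAI).
patched_other = _get_instructor_client(client, "someprovider")
```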
```diff
@@ -455,8 +511,8 @@ def llm_factory(
 
     Args:
         model: Model name (e.g., "gpt-4o", "gpt-4o-mini", "claude-3-sonnet").
-        provider: LLM provider. Default: "openai".
-            Supported: openai, anthropic, google, litellm.
+        provider: LLM provider (default: "openai").
+            Can be any provider supported by Instructor: openai, anthropic, google, litellm, etc.
         client: Pre-initialized client instance (required). For OpenAI, can be
             OpenAI(...) or AsyncOpenAI(...).
         **kwargs: Additional model arguments (temperature, max_tokens, top_p, etc).
@@ -471,13 +527,18 @@ def llm_factory(
         from openai import OpenAI
 
         client = OpenAI(api_key="...")
-        llm = llm_factory("gpt-4o", client=client)
+        llm = llm_factory("gpt-4o-mini", client=client)
         response = llm.generate(prompt, ResponseModel)
 
+        # Anthropic
+        from anthropic import Anthropic
+        client = Anthropic(api_key="...")
+        llm = llm_factory("claude-3-sonnet", provider="anthropic", client=client)
+
         # Async
         from openai import AsyncOpenAI
         client = AsyncOpenAI(api_key="...")
-        llm = llm_factory("gpt-4o", client=client)
+        llm = llm_factory("gpt-4o-mini", client=client)
         response = await llm.agenerate(prompt, ResponseModel)
     """
     if client is None:
```
```diff
@@ -496,21 +557,8 @@ def llm_factory(
 
     provider_lower = provider.lower()
 
-    instructor_funcs = {
-        "openai": lambda c: instructor.from_openai(c),
-        "anthropic": lambda c: instructor.from_anthropic(c),
-        "google": lambda c: instructor.from_gemini(c),
-        "litellm": lambda c: instructor.from_litellm(c),
-    }
-
-    if provider_lower not in instructor_funcs:
-        raise ValueError(
-            f"Unsupported provider: '{provider}'. "
-            f"Supported: {', '.join(instructor_funcs.keys())}"
-        )
-
     try:
-        patched_client = instructor_funcs[provider_lower](client)
+        patched_client = _get_instructor_client(client, provider_lower)
     except Exception as e:
         raise ValueError(
             f"Failed to initialize {provider} client with instructor. "
```
```diff
@@ -603,29 +651,33 @@ def _map_provider_params(self) -> t.Dict[str, t.Any]:
 
         Each provider may have different parameter requirements:
         - Google: Wraps parameters in generation_config and renames max_tokens
-        - OpenAI: Maps max_tokens to max_completion_tokens for o-series models
+        - OpenAI/Azure: Maps max_tokens to max_completion_tokens for o-series models
         - Anthropic: No special handling required (pass-through)
         - LiteLLM: No special handling required (routes internally, pass-through)
         """
         provider_lower = self.provider.lower()
 
         if provider_lower == "google":
             return self._map_google_params()
-        elif provider_lower == "openai":
+        elif provider_lower in ("openai", "azure"):
             return self._map_openai_params()
         else:
-            # Anthropic, LiteLLM - pass through unchanged
+            # Anthropic, LiteLLM, and other providers - pass through unchanged
             return self.model_args.copy()
 
     def _map_openai_params(self) -> t.Dict[str, t.Any]:
-        """Map parameters for OpenAI reasoning models with special constraints.
+        """Map parameters for OpenAI/Azure reasoning models with special constraints.
 
         Reasoning models (o-series and gpt-5 series) have unique requirements:
         1. max_tokens must be mapped to max_completion_tokens
         2. temperature must be set to 1.0 (only supported value)
         3. top_p parameter must be removed (not supported)
 
-        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+        Legacy OpenAI/Azure models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Note on Azure deployments: Some Azure deployments restrict temperature to 1.0.
+        If your Azure deployment has this constraint, pass temperature=1.0 explicitly:
+            llm_factory("gpt-4o-mini", provider="azure", client=client, temperature=1.0)
 
         For GPT-5 and o-series models with structured output (Pydantic models):
         - Default max_tokens=1024 may not be sufficient
```
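This hunk only touches the docstring; as a reference, here is a hypothetical sketch of the three mapping rules it describes (the function name and the model-prefix check are assumptions, not the file's actual implementation):

```python
def _sketch_map_openai_params(model: str, model_args: dict) -> dict:
    """Illustrative only: applies rules 1-3 from the docstring above."""
    mapped = model_args.copy()
    if model.startswith(("o1", "o3", "gpt-5")):  # reasoning-model check is assumed
        if "max_tokens" in mapped:
            mapped["max_completion_tokens"] = mapped.pop("max_tokens")  # rule 1
        mapped["temperature"] = 1.0  # rule 2: only supported value
        mapped.pop("top_p", None)  # rule 3: top_p is not supported
    return mapped
```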

tests/unit/llms/test_instructor_factory.py (+28, -22)
```diff
@@ -17,18 +17,22 @@ def __init__(self, is_async=False):
         self.is_async = is_async
         self.chat = Mock()
         self.chat.completions = Mock()
+        self.messages = Mock()
+        self.messages.create = Mock()
         if is_async:
 
             async def async_create(*args, **kwargs):
                 return LLMResponseModel(response="Mock response")
 
             self.chat.completions.create = async_create
+            self.messages.create = async_create
         else:
 
             def sync_create(*args, **kwargs):
                 return LLMResponseModel(response="Mock response")
 
             self.chat.completions.create = sync_create
+            self.messages.create = sync_create
 
 
 class MockInstructor:
```
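The mock gains `.messages.create` because the generic patcher binds exactly that attribute. A minimal sanity check one might add on top of this change (hypothetical test, not part of the diff):

```python
def test_mock_client_exposes_messages_create():
    client = MockClient(is_async=False)
    # the generic wrapper path calls client.messages.create under the hood
    assert client.messages.create().response == "Mock response"
```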
```diff
@@ -109,11 +113,12 @@ def mock_from_openai(client):
     assert llm.model_args.get("temperature") == 0.7  # type: ignore
 
 
-def test_unsupported_provider():
-    """Test that unsupported providers raise ValueError."""
+def test_unsupported_provider(monkeypatch):
+    """Test that invalid clients are handled gracefully for unknown providers."""
     mock_client = Mock()
+    mock_client.messages = None
 
-    with pytest.raises(ValueError, match="Unsupported provider"):
+    with pytest.raises(ValueError, match="Failed to initialize"):
         llm_factory("test-model", provider="unsupported", client=mock_client)
```
```diff
@@ -168,28 +173,29 @@ def mock_from_openai(client):
     asyncio.run(llm.agenerate("Test prompt", LLMResponseModel))
 
 
-def test_provider_support():
-    """Test that all expected providers are supported."""
-    supported_providers = {
-        "openai": "from_openai",
-        "anthropic": "from_anthropic",
-        "google": "from_gemini",
-        "litellm": "from_litellm",
-    }
+def test_provider_support(monkeypatch):
+    """Test that major providers are supported."""
 
-    for provider, func_name in supported_providers.items():
-        mock_client = Mock()
-
-        import instructor
+    # OpenAI and LiteLLM use provider-specific methods
+    def mock_from_openai(client):
+        return MockInstructor(client)
 
-        mock_instructor_func = Mock(return_value=MockInstructor(mock_client))
-        setattr(instructor, func_name, mock_instructor_func)
+    def mock_from_litellm(client):
+        return MockInstructor(client)
 
-        try:
-            llm = llm_factory("test-model", provider=provider, client=mock_client)
-            assert llm.model == "test-model"  # type: ignore
-        except Exception as e:
-            pytest.fail(f"Provider {provider} should be supported but got error: {e}")
+    monkeypatch.setattr("instructor.from_openai", mock_from_openai)
+    monkeypatch.setattr("instructor.from_litellm", mock_from_litellm)
+
+    for provider in ["openai", "litellm"]:
+        mock_client = MockClient(is_async=False)
+        llm = llm_factory("test-model", provider=provider, client=mock_client)
+        assert llm.model == "test-model"  # type: ignore
+
+    # Anthropic and Google use generic wrapper
+    for provider in ["anthropic", "google"]:
+        mock_client = MockClient(is_async=False)
+        llm = llm_factory("test-model", provider=provider, client=mock_client)
+        assert llm.model == "test-model"  # type: ignore
 
 
 def test_llm_model_args_storage(mock_sync_client, monkeypatch):
```
