
Commit 59dd875

feat: migrate to instructor.from_provider for universal provider support (#2424)
A fix to support the latest instructor release: upstream removed the `from_anthropic` and `from_gemini` methods in favor of a more standard `from_provider`. Ref: [PR 1898](567-labs/instructor#1898). Also adds support for #2422.
1 parent 10311c0 commit 59dd875
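For context, a minimal sketch of the upstream API change this commit adapts to. Recent instructor releases expose a single `from_provider` entry point that takes a "provider/model" string; the model name below is illustrative, not prescriptive:

```python
import instructor

# Before (removed upstream): provider-specific constructors such as
# instructor.from_anthropic(...) and instructor.from_gemini(...).

# After: one universal entry point that resolves the provider from the
# model string. The exact model id here is illustrative.
client = instructor.from_provider("anthropic/claude-3-5-sonnet-latest")
```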

2 files changed: +103 −45 lines


src/ragas/llms/base.py

Lines changed: 75 additions & 23 deletions
```diff
@@ -590,6 +590,62 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(llm={self.llm.__class__.__name__}(...))"
 
 
+def _patch_client_for_provider(client: t.Any, provider: str) -> t.Any:
+    """
+    Patch a client with Instructor for generic providers.
+
+    Maps provider names to Provider enum and instantiates Instructor/AsyncInstructor.
+    Supports anthropic, google, and any other provider Instructor recognizes.
+    """
+    from instructor import Provider
+
+    provider_map = {
+        "anthropic": Provider.ANTHROPIC,
+        "google": Provider.GENAI,
+        "gemini": Provider.GENAI,
+        "azure": Provider.OPENAI,
+        "groq": Provider.GROQ,
+        "mistral": Provider.MISTRAL,
+        "cohere": Provider.COHERE,
+        "xai": Provider.XAI,
+        "bedrock": Provider.BEDROCK,
+        "deepseek": Provider.DEEPSEEK,
+    }
+
+    provider_enum = provider_map.get(provider, Provider.OPENAI)
+
+    if hasattr(client, "acompletion"):
+        return instructor.AsyncInstructor(
+            client=client,
+            create=client.messages.create,
+            provider=provider_enum,
+        )
+    else:
+        return instructor.Instructor(
+            client=client,
+            create=client.messages.create,
+            provider=provider_enum,
+        )
+
+
+def _get_instructor_client(client: t.Any, provider: str) -> t.Any:
+    """
+    Get an instructor-patched client for the specified provider.
+
+    Uses provider-specific methods when available, falls back to generic patcher.
+    """
+    provider_lower = provider.lower()
+
+    if provider_lower == "openai":
+        return instructor.from_openai(client)
+    elif provider_lower == "litellm":
+        return instructor.from_litellm(client)
+    elif provider_lower == "perplexity":
+        return instructor.from_perplexity(client)
+    else:
+        return _patch_client_for_provider(client, provider_lower)
+
+
 def llm_factory(
     model: str,
     provider: str = "openai",
```
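For illustration, a hedged sketch of how these new helpers route a non-OpenAI client. It assumes the `anthropic` package is installed; the helpers themselves are module-private to `ragas.llms.base`:

```python
# Hypothetical usage of the new helper (not part of the diff).
from anthropic import Anthropic

client = Anthropic(api_key="...")
# "anthropic" has no provider-specific branch, so _get_instructor_client
# falls through to _patch_client_for_provider, which maps the name to
# Provider.ANTHROPIC and binds client.messages.create.
patched = _get_instructor_client(client, "anthropic")
```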
```diff
@@ -605,8 +661,8 @@ def llm_factory(
 
     Args:
         model: Model name (e.g., "gpt-4o", "gpt-4o-mini", "claude-3-sonnet").
-        provider: LLM provider. Default: "openai".
-            Supported: openai, anthropic, google, litellm.
+        provider: LLM provider (default: "openai").
+            Can be any provider supported by Instructor: openai, anthropic, google, litellm, etc.
         client: Pre-initialized client instance (required). For OpenAI, can be
             OpenAI(...) or AsyncOpenAI(...).
         **kwargs: Additional model arguments (temperature, max_tokens, top_p, etc).
```
```diff
@@ -621,13 +677,18 @@ def llm_factory(
         from openai import OpenAI
 
         client = OpenAI(api_key="...")
-        llm = llm_factory("gpt-4o", client=client)
+        llm = llm_factory("gpt-4o-mini", client=client)
         response = llm.generate(prompt, ResponseModel)
 
+        # Anthropic
+        from anthropic import Anthropic
+        client = Anthropic(api_key="...")
+        llm = llm_factory("claude-3-sonnet", provider="anthropic", client=client)
+
         # Async
         from openai import AsyncOpenAI
         client = AsyncOpenAI(api_key="...")
-        llm = llm_factory("gpt-4o", client=client)
+        llm = llm_factory("gpt-4o-mini", client=client)
         response = await llm.agenerate(prompt, ResponseModel)
     """
     if client is None:
```
```diff
@@ -646,21 +707,8 @@ def llm_factory(
 
     provider_lower = provider.lower()
 
-    instructor_funcs = {
-        "openai": lambda c: instructor.from_openai(c),
-        "anthropic": lambda c: instructor.from_anthropic(c),
-        "google": lambda c: instructor.from_gemini(c),
-        "litellm": lambda c: instructor.from_litellm(c),
-    }
-
-    if provider_lower not in instructor_funcs:
-        raise ValueError(
-            f"Unsupported provider: '{provider}'. "
-            f"Supported: {', '.join(instructor_funcs.keys())}"
-        )
-
     try:
-        patched_client = instructor_funcs[provider_lower](client)
+        patched_client = _get_instructor_client(client, provider_lower)
     except Exception as e:
         raise ValueError(
             f"Failed to initialize {provider} client with instructor. "
```
```diff
@@ -753,29 +801,33 @@ def _map_provider_params(self) -> t.Dict[str, t.Any]:
 
         Each provider may have different parameter requirements:
         - Google: Wraps parameters in generation_config and renames max_tokens
-        - OpenAI: Maps max_tokens to max_completion_tokens for o-series models
+        - OpenAI/Azure: Maps max_tokens to max_completion_tokens for o-series models
         - Anthropic: No special handling required (pass-through)
         - LiteLLM: No special handling required (routes internally, pass-through)
         """
         provider_lower = self.provider.lower()
 
         if provider_lower == "google":
             return self._map_google_params()
-        elif provider_lower == "openai":
+        elif provider_lower in ("openai", "azure"):
             return self._map_openai_params()
         else:
-            # Anthropic, LiteLLM - pass through unchanged
+            # Anthropic, LiteLLM, and other providers - pass through unchanged
             return self.model_args.copy()
 
     def _map_openai_params(self) -> t.Dict[str, t.Any]:
-        """Map parameters for OpenAI reasoning models with special constraints.
+        """Map parameters for OpenAI/Azure reasoning models with special constraints.
 
         Reasoning models (o-series and gpt-5 series) have unique requirements:
         1. max_tokens must be mapped to max_completion_tokens
         2. temperature must be set to 1.0 (only supported value)
         3. top_p parameter must be removed (not supported)
 
-        Legacy OpenAI models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+        Legacy OpenAI/Azure models (gpt-4, gpt-4o, etc.) continue to use max_tokens unchanged.
+
+        Note on Azure deployments: Some Azure deployments restrict temperature to 1.0.
+        If your Azure deployment has this constraint, pass temperature=1.0 explicitly:
+            llm_factory("gpt-4o-mini", provider="azure", client=client, temperature=1.0)
 
         For GPT-5 and o-series models with structured output (Pydantic models):
         - Default max_tokens=1024 may not be sufficient
```
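As a standalone restatement of the three mapping rules in that docstring (the function name and the model-prefix check are assumptions for illustration, not the library's internals):

```python
# Hypothetical sketch of the reasoning-model parameter mapping described above.
def sketch_map_reasoning_params(model: str, args: dict) -> dict:
    mapped = dict(args)
    if model.startswith(("o1", "o3", "o4", "gpt-5")):  # prefix check is assumed
        if "max_tokens" in mapped:
            # Rule 1: rename max_tokens -> max_completion_tokens
            mapped["max_completion_tokens"] = mapped.pop("max_tokens")
        mapped["temperature"] = 1.0  # Rule 2: only supported value
        mapped.pop("top_p", None)    # Rule 3: not supported
    return mapped
```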

tests/unit/llms/test_instructor_factory.py

Lines changed: 28 additions & 22 deletions
```diff
@@ -17,18 +17,22 @@ def __init__(self, is_async=False):
         self.is_async = is_async
         self.chat = Mock()
         self.chat.completions = Mock()
+        self.messages = Mock()
+        self.messages.create = Mock()
         if is_async:
 
             async def async_create(*args, **kwargs):
                 return LLMResponseModel(response="Mock response")
 
             self.chat.completions.create = async_create
+            self.messages.create = async_create
         else:
 
             def sync_create(*args, **kwargs):
                 return LLMResponseModel(response="Mock response")
 
             self.chat.completions.create = sync_create
+            self.messages.create = sync_create
 
 
 class MockInstructor:
```
```diff
@@ -109,11 +113,12 @@ def mock_from_openai(client):
     assert llm.model_args.get("temperature") == 0.7  # type: ignore
 
 
-def test_unsupported_provider():
-    """Test that unsupported providers raise ValueError."""
+def test_unsupported_provider(monkeypatch):
+    """Test that invalid clients are handled gracefully for unknown providers."""
     mock_client = Mock()
+    mock_client.messages = None
 
-    with pytest.raises(ValueError, match="Unsupported provider"):
+    with pytest.raises(ValueError, match="Failed to initialize"):
         llm_factory("test-model", provider="unsupported", client=mock_client)
 
 
```
```diff
@@ -168,28 +173,29 @@ def mock_from_openai(client):
     asyncio.run(llm.agenerate("Test prompt", LLMResponseModel))
 
 
-def test_provider_support():
-    """Test that all expected providers are supported."""
-    supported_providers = {
-        "openai": "from_openai",
-        "anthropic": "from_anthropic",
-        "google": "from_gemini",
-        "litellm": "from_litellm",
-    }
+def test_provider_support(monkeypatch):
+    """Test that major providers are supported."""
 
-    for provider, func_name in supported_providers.items():
-        mock_client = Mock()
-
-        import instructor
+    # OpenAI and LiteLLM use provider-specific methods
+    def mock_from_openai(client):
+        return MockInstructor(client)
 
-        mock_instructor_func = Mock(return_value=MockInstructor(mock_client))
-        setattr(instructor, func_name, mock_instructor_func)
+    def mock_from_litellm(client):
+        return MockInstructor(client)
 
-        try:
-            llm = llm_factory("test-model", provider=provider, client=mock_client)
-            assert llm.model == "test-model"  # type: ignore
-        except Exception as e:
-            pytest.fail(f"Provider {provider} should be supported but got error: {e}")
+    monkeypatch.setattr("instructor.from_openai", mock_from_openai)
+    monkeypatch.setattr("instructor.from_litellm", mock_from_litellm)
+
+    for provider in ["openai", "litellm"]:
+        mock_client = MockClient(is_async=False)
+        llm = llm_factory("test-model", provider=provider, client=mock_client)
+        assert llm.model == "test-model"  # type: ignore
+
+    # Anthropic and Google use generic wrapper
+    for provider in ["anthropic", "google"]:
+        mock_client = MockClient(is_async=False)
+        llm = llm_factory("test-model", provider=provider, client=mock_client)
+        assert llm.model == "test-model"  # type: ignore
 
 
 def test_llm_model_args_storage(mock_sync_client, monkeypatch):
```
