Commit 85d47a1

mochow13 and Motta Kin authored
Add LiteLLM provider for OpenAI API compatible models (#2606)
Co-authored-by: Motta Kin <[email protected]>
1 parent 881cd7a commit 85d47a1

File tree

9 files changed, +335 −0 lines changed


docs/api/providers.md

Lines changed: 2 additions & 0 deletions
@@ -39,3 +39,5 @@
 ::: pydantic_ai.providers.moonshotai.MoonshotAIProvider

 ::: pydantic_ai.providers.ollama.OllamaProvider
+
+::: pydantic_ai.providers.litellm.LiteLLMProvider

docs/models/openai.md

Lines changed: 28 additions & 0 deletions
@@ -563,3 +563,31 @@ result = agent.run_sync('What is the capital of France?')
 print(result.output)
 #> The capital of France is Paris.
 ```
+
+### LiteLLM
+
+To use [LiteLLM](https://www.litellm.ai/), set the configuration as outlined in the [LiteLLM docs](https://docs.litellm.ai/docs/set_keys). `LiteLLMProvider` accepts `api_base` and `api_key`; their values depend on your setup. For example, to call OpenAI models, pass `https://api.openai.com/v1` as `api_base` and your OpenAI API key as `api_key`. To use a LiteLLM proxy server running on your local machine, pass `http://localhost:<port>` as `api_base` and your LiteLLM API key (or a placeholder) as `api_key`.
+
+To use custom LLMs, prefix the model name with `custom/`.
+
+Once configured, use the [`LiteLLMProvider`][pydantic_ai.providers.litellm.LiteLLMProvider] as follows:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.litellm import LiteLLMProvider
+
+model = OpenAIChatModel(
+    'openai/gpt-3.5-turbo',
+    provider=LiteLLMProvider(
+        api_base='<api-base-url>',
+        api_key='<api-key>'
+    )
+)
+agent = Agent(model)
+
+result = agent.run_sync('What is the capital of France?')
+print(result.output)
+#> The capital of France is Paris.
+...
+```
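For the local-proxy and `custom/` cases described above, here is a minimal sketch (not part of this commit; the port, key, and model name are illustrative placeholders):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.litellm import LiteLLMProvider

# Hypothetical setup: a LiteLLM proxy on localhost serving a model registered
# under the `custom/` prefix; substitute your own port, key, and model name.
model = OpenAIChatModel(
    'custom/my-model',
    provider=LiteLLMProvider(
        api_base='http://localhost:4000',  # placeholder proxy port
        api_key='my-litellm-key',  # placeholder key
    ),
)
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.output)
```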

docs/models/overview.md

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 - [Heroku](openai.md#heroku-ai)
 - [GitHub Models](openai.md#github-models)
 - [Cerebras](openai.md#cerebras)
+- [LiteLLM](openai.md#litellm)

 Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md)
 for testing and development.

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -728,6 +728,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'openrouter',
         'together',
         'vercel',
+        'litellm',
     ):
         from .openai import OpenAIChatModel
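Registering `'litellm'` in this tuple means the provider-prefixed model string shorthand should also resolve; a sketch, assuming the API key and base URL are supplied via environment variables or LiteLLM config:

```python
from pydantic_ai import Agent

# 'litellm:<model>' now infers OpenAIChatModel backed by LiteLLMProvider,
# the same path the other OpenAI-compatible providers in this tuple take.
agent = Agent('litellm:openai/gpt-3.5-turbo')
```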

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 3 additions & 0 deletions
@@ -225,6 +225,7 @@ def __init__(
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -252,6 +253,7 @@ def __init__(
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -278,6 +280,7 @@ def __init__(
             'openrouter',
             'together',
             'vercel',
+            'litellm',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
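Adding `'litellm'` to these `Literal` types lets the provider be selected by name; a sketch, under the same configuration assumptions as above:

```python
from pydantic_ai.models.openai import OpenAIChatModel

# The string is resolved to LiteLLMProvider via provider inference;
# the provider then builds its client from the given or ambient config.
model = OpenAIChatModel('openai/gpt-3.5-turbo', provider='litellm')
```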

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .github import GitHubProvider

         return GitHubProvider
+    elif provider == 'litellm':
+        from .litellm import LiteLLMProvider
+
+        return LiteLLMProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
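A quick illustration of the new branch (assuming the `openai` extra is installed, since the import inside it is guarded):

```python
from pydantic_ai.providers import infer_provider_class
from pydantic_ai.providers.litellm import LiteLLMProvider

# The string 'litellm' now maps to the provider class.
assert infer_provider_class('litellm') is LiteLLMProvider
```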

pydantic_ai_slim/pydantic_ai/providers/litellm.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
from __future__ import annotations as _annotations

from typing import overload

from httpx import AsyncClient as AsyncHTTPClient

from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.amazon import amazon_model_profile
from pydantic_ai.profiles.anthropic import anthropic_model_profile
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import google_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.groq import groq_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the LiteLLM provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class LiteLLMProvider(Provider[AsyncOpenAI]):
    """Provider for the LiteLLM API."""

    @property
    def name(self) -> str:
        return 'litellm'

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        # Map provider prefixes to their profile functions
        provider_to_profile = {
            'anthropic': anthropic_model_profile,
            'openai': openai_model_profile,
            'google': google_model_profile,
            'mistralai': mistral_model_profile,
            'mistral': mistral_model_profile,
            'cohere': cohere_model_profile,
            'amazon': amazon_model_profile,
            'bedrock': amazon_model_profile,
            'meta-llama': meta_model_profile,
            'meta': meta_model_profile,
            'groq': groq_model_profile,
            'deepseek': deepseek_model_profile,
            'moonshotai': moonshotai_model_profile,
            'x-ai': grok_model_profile,
            'qwen': qwen_model_profile,
        }

        profile = None

        # Check if the model name contains a provider prefix (e.g. 'anthropic/claude-3')
        if '/' in model_name:
            provider_prefix, model_suffix = model_name.split('/', 1)
            if provider_prefix in provider_to_profile:
                profile = provider_to_profile[provider_prefix](model_suffix)

        # If no profile was found, default to the OpenAI profile
        if profile is None:
            profile = openai_model_profile(model_name)

        # As LiteLLMProvider is used with OpenAIModel, which uses OpenAIJsonSchemaTransformer,
        # we maintain that behavior
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        http_client: AsyncHTTPClient,
    ) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Initialize a LiteLLM provider.

        Args:
            api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables.
            api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models.
            openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored.
            http_client: Custom HTTP client to use.
        """
        if openai_client is not None:
            self._client = openai_client
            return

        # Create an OpenAI-compatible client pointed at the LiteLLM endpoint;
        # if no HTTP client is supplied, fall back to the shared cached client.
        if http_client is None:
            http_client = cached_async_http_client(provider='litellm')
        self._client = AsyncOpenAI(
            base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
        )
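To make the prefix routing in `model_profile` concrete, a short sketch (not part of the commit; the model names are illustrative):

```python
from pydantic_ai.providers.litellm import LiteLLMProvider

provider = LiteLLMProvider(api_key='placeholder')

# 'provider/model' names route to that provider's profile ('x-ai' maps to the
# grok profile, 'bedrock' to amazon, etc.); unknown or missing prefixes fall
# back to openai_model_profile on the full name.
anthropic_profile = provider.model_profile('anthropic/claude-3-haiku')
fallback_profile = provider.model_profile('my-gateway/some-model')

# Either way, the result is wrapped in an OpenAIModelProfile carrying
# OpenAIJsonSchemaTransformer, since the provider is always used through
# the OpenAI-compatible model class.
```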

tests/providers/test_litellm.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import httpx
import pytest
from pytest_mock import MockerFixture

from ..conftest import try_import

with try_import() as imports_successful:
    from openai import AsyncOpenAI

    from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
    from pydantic_ai.providers.litellm import LiteLLMProvider

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='OpenAI client not installed'),
    pytest.mark.anyio,
]


def test_init_with_api_config():
    provider = LiteLLMProvider(api_key='test-key', api_base='https://custom.litellm.com/v1')
    assert provider.base_url == 'https://custom.litellm.com/v1/'
    assert provider.client.api_key == 'test-key'


def test_init_without_api_key():
    provider = LiteLLMProvider()
    assert provider.name == 'litellm'
    assert provider.client.api_key == 'litellm-placeholder'


def test_init_with_openai_client():
    openai_client = AsyncOpenAI(api_key='custom-key', base_url='https://custom.openai.com/v1')
    provider = LiteLLMProvider(openai_client=openai_client)
    assert provider.client == openai_client
    assert provider.base_url == 'https://custom.openai.com/v1/'


def test_model_profile_returns_openai_compatible_profile(mocker: MockerFixture):
    provider = LiteLLMProvider(api_key='test-key')

    # Create a proper mock profile object that can be updated
    from dataclasses import dataclass

    @dataclass
    class MockProfile:
        max_tokens: int = 4096
        supports_streaming: bool = True

    mock_profile = MockProfile()
    mock_openai_profile = mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=mock_profile)

    profile = provider.model_profile('gpt-3.5-turbo')

    # Verify openai_model_profile was called with the model name
    mock_openai_profile.assert_called_once_with('gpt-3.5-turbo')

    # Verify the returned profile is an OpenAIModelProfile with OpenAIJsonSchemaTransformer
    assert isinstance(profile, OpenAIModelProfile)
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer


def test_model_profile_with_different_models(mocker: MockerFixture):
    provider = LiteLLMProvider(api_key='test-key')

    # Create mocks for all profile functions
    from dataclasses import dataclass

    @dataclass
    class MockProfile:
        max_tokens: int = 4096
        supports_streaming: bool = True

    # Mock all profile functions
    mock_profiles = {
        'openai': mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=MockProfile()),
        'anthropic': mocker.patch('pydantic_ai.providers.litellm.anthropic_model_profile', return_value=MockProfile()),
        'google': mocker.patch('pydantic_ai.providers.litellm.google_model_profile', return_value=MockProfile()),
        'meta': mocker.patch('pydantic_ai.providers.litellm.meta_model_profile', return_value=MockProfile()),
        'mistral': mocker.patch('pydantic_ai.providers.litellm.mistral_model_profile', return_value=MockProfile()),
        'cohere': mocker.patch('pydantic_ai.providers.litellm.cohere_model_profile', return_value=MockProfile()),
        'amazon': mocker.patch('pydantic_ai.providers.litellm.amazon_model_profile', return_value=MockProfile()),
        'deepseek': mocker.patch('pydantic_ai.providers.litellm.deepseek_model_profile', return_value=MockProfile()),
        'groq': mocker.patch('pydantic_ai.providers.litellm.groq_model_profile', return_value=MockProfile()),
        'grok': mocker.patch('pydantic_ai.providers.litellm.grok_model_profile', return_value=MockProfile()),
        'moonshotai': mocker.patch(
            'pydantic_ai.providers.litellm.moonshotai_model_profile', return_value=MockProfile()
        ),
        'qwen': mocker.patch('pydantic_ai.providers.litellm.qwen_model_profile', return_value=MockProfile()),
    }

    # Test models without a provider prefix (should use the openai profile)
    models_without_prefix = ['gpt-4', 'claude-3-sonnet', 'gemini-pro', 'llama2-70b']

    for model in models_without_prefix:
        profile = provider.model_profile(model)
        assert isinstance(profile, OpenAIModelProfile)
        assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Verify openai_model_profile was called for each model without a prefix
    assert mock_profiles['openai'].call_count == len(models_without_prefix)

    # Reset all call counts
    for mock in mock_profiles.values():
        mock.reset_mock()

    # Test all provider prefixes
    test_cases = [
        ('anthropic/claude-3-haiku', 'anthropic', 'claude-3-haiku'),
        ('openai/gpt-4-turbo', 'openai', 'gpt-4-turbo'),
        ('google/gemini-1.5-pro', 'google', 'gemini-1.5-pro'),
        ('mistralai/mistral-large', 'mistral', 'mistral-large'),
        ('mistral/mistral-7b', 'mistral', 'mistral-7b'),
        ('cohere/command-r', 'cohere', 'command-r'),
        ('amazon/titan-text', 'amazon', 'titan-text'),
        ('bedrock/claude-v2', 'amazon', 'claude-v2'),
        ('meta-llama/llama-3-8b', 'meta', 'llama-3-8b'),
        ('meta/llama-2-70b', 'meta', 'llama-2-70b'),
        ('groq/llama3-70b', 'groq', 'llama3-70b'),
        ('deepseek/deepseek-coder', 'deepseek', 'deepseek-coder'),
        ('moonshotai/moonshot-v1', 'moonshotai', 'moonshot-v1'),
        ('x-ai/grok-beta', 'grok', 'grok-beta'),
        ('qwen/qwen-72b', 'qwen', 'qwen-72b'),
    ]

    for model_name, expected_profile, expected_suffix in test_cases:
        profile = provider.model_profile(model_name)
        assert isinstance(profile, OpenAIModelProfile)
        assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
        # Verify the correct profile function was called with the correct suffix
        mock_profiles[expected_profile].assert_called_with(expected_suffix)
        mock_profiles[expected_profile].reset_mock()

    # Test an unknown provider prefix (should fall back to openai)
    provider.model_profile('unknown-provider/some-model')
    mock_profiles['openai'].assert_called_once_with('unknown-provider/some-model')


async def test_cached_http_client_usage(mocker: MockerFixture):
    # Create a real AsyncClient for the mock
    async with httpx.AsyncClient() as mock_cached_client:
        mock_cached_http_client_func = mocker.patch(
            'pydantic_ai.providers.litellm.cached_async_http_client', return_value=mock_cached_client
        )

        provider = LiteLLMProvider(api_key='test-key')

        # Verify cached_async_http_client was called with the 'litellm' provider
        mock_cached_http_client_func.assert_called_once_with(provider='litellm')

        # Verify the client was created
        assert isinstance(provider.client, AsyncOpenAI)


async def test_init_with_http_client_overrides_cached():
    async with httpx.AsyncClient() as custom_client:
        provider = LiteLLMProvider(api_key='test-key', http_client=custom_client)

        # Verify the provider was created successfully with the custom client
        assert isinstance(provider.client, AsyncOpenAI)
        assert provider.client.api_key == 'test-key'
