Skip to content

Commit 00cf42d

Browse files
author
Motta Kin
committed
Add LiteLLM provider for OpenAI API compatible models
1 parent 10339bb commit 00cf42d

File tree

4 files changed

+240
-0
lines changed

4 files changed

+240
-0
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def __init__(
225225
'openrouter',
226226
'together',
227227
'vercel',
228+
'litellm',
228229
]
229230
| Provider[AsyncOpenAI] = 'openai',
230231
profile: ModelProfileSpec | None = None,
@@ -252,6 +253,7 @@ def __init__(
252253
'openrouter',
253254
'together',
254255
'vercel',
256+
'litellm',
255257
]
256258
| Provider[AsyncOpenAI] = 'openai',
257259
profile: ModelProfileSpec | None = None,
@@ -278,6 +280,7 @@ def __init__(
278280
'openrouter',
279281
'together',
280282
'vercel',
283+
'litellm',
281284
]
282285
| Provider[AsyncOpenAI] = 'openai',
283286
profile: ModelProfileSpec | None = None,

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
135135
from .github import GitHubProvider
136136

137137
return GitHubProvider
138+
elif provider == 'litellm':
139+
from .litellm import LiteLLMProvider
140+
141+
return LiteLLMProvider
138142
else: # pragma: no cover
139143
raise ValueError(f'Unknown provider: {provider}')
140144

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from __future__ import annotations as _annotations

from typing import overload

from httpx import AsyncClient as AsyncHTTPClient

from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.providers import Provider

# `openai` is an optional dependency: import it only inside the guard so that a
# missing package produces the friendly install hint below instead of a bare
# ImportError. (An unconditional top-level `from openai import AsyncOpenAI`
# would fire first and make this guard unreachable.)
try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the LiteLLM provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error
20+
21+
22+
class LiteLLMProvider(Provider[AsyncOpenAI]):
    """Provider for the LiteLLM API.

    LiteLLM exposes an OpenAI-compatible endpoint, so this provider wraps an
    ``AsyncOpenAI`` client pointed at the LiteLLM base URL.
    """

    @property
    def name(self) -> str:
        """Provider identifier used in model strings, e.g. ``litellm:<model>``."""
        return 'litellm'

    @property
    def base_url(self) -> str:
        """Base URL the wrapped OpenAI client talks to."""
        return self._base_url

    @property
    def client(self) -> AsyncOpenAI:
        """The underlying ``AsyncOpenAI`` client."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Return a profile for *model_name* based on the generic OpenAI profile.

        For LiteLLM, we use a basic OpenAI profile since it's OpenAI-compatible;
        users can override this with their own profile if needed.
        """
        profile = openai_model_profile(model_name)

        # As LiteLLMProvider is used with OpenAIModel, which used to use
        # OpenAIJsonSchemaTransformer, we maintain that behavior.
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        custom_llm_provider: str | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        custom_llm_provider: str | None = None,
        http_client: AsyncHTTPClient,
    ) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        custom_llm_provider: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Initialize a LiteLLM provider.

        Args:
            api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables.
            api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models.
            custom_llm_provider: Custom LLM provider name for LiteLLM. Use this if LiteLLM can't auto-detect the provider.
            openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored.
            http_client: Custom HTTP client to use.
        """
        if openai_client is not None:
            # A pre-configured client wins outright; all other parameters are ignored.
            self._client = openai_client
            self._base_url = str(openai_client.base_url)
            return

        # Use api_base if provided, otherwise a generic placeholder base URL —
        # LiteLLM doesn't actually use this URL, it routes internally.
        self._base_url = api_base or 'https://api.litellm.ai/v1'

        # Reuse the shared cached HTTP client unless the caller supplied one,
        # then build a single OpenAI client for either case. The actual API
        # calls will be intercepted and routed through LiteLLM.
        if http_client is None:
            http_client = cached_async_http_client(provider='litellm')
        self._client = AsyncOpenAI(
            base_url=self._base_url, api_key=api_key or 'litellm-placeholder', http_client=http_client
        )

        # Store configuration for LiteLLM-specific routing.
        self._api_key = api_key
        self._api_base = api_base
        self._custom_llm_provider = custom_llm_provider

tests/providers/test_litellm.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import httpx
2+
import pytest
3+
from pytest_mock import MockerFixture
4+
5+
from ..conftest import try_import
6+
7+
with try_import() as imports_successful:
8+
from openai import AsyncOpenAI
9+
10+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
11+
from pydantic_ai.providers.litellm import LiteLLMProvider
12+
13+
pytestmark = [
14+
pytest.mark.skipif(not imports_successful(), reason='OpenAI client not installed'),
15+
pytest.mark.anyio,
16+
]
17+
18+
19+
def test_init_with_api_config():
    """A custom api_base and api_key are propagated to the wrapped client."""
    p = LiteLLMProvider(api_key='test-key', api_base='https://custom.litellm.com/v1')

    assert p.base_url == 'https://custom.litellm.com/v1'
    assert p.client.api_key == 'test-key'
23+
24+
25+
def test_init_with_custom_llm_provider():
    """Passing custom_llm_provider does not change the provider's name."""
    p = LiteLLMProvider(api_key='test-key', custom_llm_provider='anthropic')

    assert p.name == 'litellm'
28+
29+
30+
def test_init_without_api_key():
    """With no configuration at all, defaults and a placeholder key are used."""
    p = LiteLLMProvider()

    assert p.name == 'litellm'
    assert p.base_url == 'https://api.litellm.ai/v1'
    assert p.client.api_key == 'litellm-placeholder'
35+
36+
37+
def test_init_with_openai_client():
    """A pre-built AsyncOpenAI client is adopted as-is, including its base URL."""
    prebuilt = AsyncOpenAI(api_key='custom-key', base_url='https://custom.openai.com/v1')

    p = LiteLLMProvider(openai_client=prebuilt)

    assert p.client == prebuilt
    assert p.base_url == 'https://custom.openai.com/v1/'
42+
43+
44+
def test_model_profile_returns_openai_compatible_profile(mocker: MockerFixture):
    """model_profile() wraps the OpenAI profile in an OpenAIModelProfile with the OpenAI transformer."""
    p = LiteLLMProvider(api_key='test-key')

    from dataclasses import dataclass

    # Minimal stand-in profile object that OpenAIModelProfile.update() can consume.
    @dataclass
    class MockProfile:
        max_tokens: int = 4096
        supports_streaming: bool = True

    patched = mocker.patch(
        'pydantic_ai.providers.litellm.openai_model_profile', return_value=MockProfile()
    )

    result = p.model_profile('gpt-3.5-turbo')

    # The base profile must be looked up exactly once for the given model name.
    patched.assert_called_once_with('gpt-3.5-turbo')

    # And wrapped in an OpenAIModelProfile carrying the OpenAI JSON schema transformer.
    assert isinstance(result, OpenAIModelProfile)
    assert result.json_schema_transformer == OpenAIJsonSchemaTransformer
66+
67+
68+
def test_model_profile_with_different_models(mocker: MockerFixture):
    """Every model-name format LiteLLM routes yields an OpenAIModelProfile."""
    p = LiteLLMProvider(api_key='test-key')

    patched = mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value={})

    # A sample of the model-name shapes LiteLLM supports.
    model_names = [
        'gpt-4',
        'claude-3-sonnet-20240229',
        'gemini-pro',
        'llama2-70b-chat',
        'anthropic/claude-3-haiku-20240307',
    ]

    for name in model_names:
        result = p.model_profile(name)
        assert isinstance(result, OpenAIModelProfile)
        assert result.json_schema_transformer == OpenAIJsonSchemaTransformer

    # One base-profile lookup per model name.
    assert patched.call_count == len(model_names)
89+
90+
91+
async def test_cached_http_client_usage(mocker: MockerFixture):
    """When no http_client is given, the shared cached client is requested."""
    # Use a real AsyncClient as the cached-client stand-in.
    async with httpx.AsyncClient() as cached:
        patched = mocker.patch(
            'pydantic_ai.providers.litellm.cached_async_http_client', return_value=cached
        )

        p = LiteLLMProvider(api_key='test-key')

        # The cache must be consulted with the 'litellm' provider key exactly once.
        patched.assert_called_once_with(provider='litellm')

        # And an OpenAI client must have been constructed around it.
        assert isinstance(p.client, AsyncOpenAI)
105+
106+
107+
async def test_init_with_http_client_overrides_cached():
    """An explicitly supplied http_client takes precedence over the cached one."""
    async with httpx.AsyncClient() as custom:
        p = LiteLLMProvider(api_key='test-key', http_client=custom)

        assert isinstance(p.client, AsyncOpenAI)
        assert p.client.api_key == 'test-key'

0 commit comments

Comments
 (0)