2 changes: 2 additions & 0 deletions docs/api/providers.md
@@ -39,3 +39,5 @@
::: pydantic_ai.providers.moonshotai.MoonshotAIProvider

::: pydantic_ai.providers.ollama.OllamaProvider

::: pydantic_ai.providers.litellm.LiteLLMProvider
28 changes: 28 additions & 0 deletions docs/models/openai.md
@@ -563,3 +563,31 @@ result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
```

### LiteLLM

To use [LiteLLM](https://www.litellm.ai/), first configure your credentials as outlined in the [LiteLLM docs](https://docs.litellm.ai/docs/set_keys). `LiteLLMProvider` accepts an `api_base` and an `api_key`; the right values depend on your setup. For example, to call OpenAI models directly, pass `https://api.openai.com/v1` as the `api_base` and your OpenAI API key as the `api_key`. To go through a LiteLLM proxy server running on your local machine, pass `http://localhost:<port>` as the `api_base` and your LiteLLM API key (or a placeholder) as the `api_key`.

To use custom LLMs, add the `custom/` prefix to the model name (for example, `custom/my-model`).

Once you have the configs, use the [`LiteLLMProvider`][pydantic_ai.providers.litellm.LiteLLMProvider] as follows:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.litellm import LiteLLMProvider

model = OpenAIChatModel(
    'openai/gpt-3.5-turbo',
    provider=LiteLLMProvider(
        api_base='<api-base-url>',
        api_key='<api-key>',
    ),
)
agent = Agent(model)

result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
...
```
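For example, to go through a LiteLLM proxy running locally, point `api_base` at the proxy. A minimal sketch, assuming a proxy listening on port 4000; the port, key, and model name here are placeholders, not values from this PR:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.litellm import LiteLLMProvider

# Hypothetical local proxy setup; adjust to your deployment
model = OpenAIChatModel(
    'gpt-4o',
    provider=LiteLLMProvider(
        api_base='http://localhost:4000',
        api_key='litellm-proxy-key',
    ),
)
agent = Agent(model)
```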
1 change: 1 addition & 0 deletions docs/models/overview.md
@@ -27,6 +27,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
- [Heroku](openai.md#heroku-ai)
- [GitHub Models](openai.md#github-models)
- [Cerebras](openai.md#cerebras)
- [LiteLLM](openai.md#litellm)

Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md)
for testing and development.
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -728,6 +728,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
        'openrouter',
        'together',
        'vercel',
        'litellm',
    ):
        from .openai import OpenAIChatModel

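With `'litellm'` added to this tuple, `infer_model` maps `litellm:`-prefixed model strings to `OpenAIChatModel`. A minimal sketch, assuming the same `provider:model` string format the other entries use (the model name is illustrative):

```python
from pydantic_ai.models import infer_model
from pydantic_ai.models.openai import OpenAIChatModel

# 'litellm:<model>' should now resolve to an OpenAIChatModel backed by LiteLLMProvider
model = infer_model('litellm:openai/gpt-4o')
assert isinstance(model, OpenAIChatModel)
```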
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -225,6 +225,7 @@ def __init__(
            'openrouter',
            'together',
            'vercel',
            'litellm',
        ]
        | Provider[AsyncOpenAI] = 'openai',
        profile: ModelProfileSpec | None = None,
@@ -252,6 +253,7 @@ def __init__(
            'openrouter',
            'together',
            'vercel',
            'litellm',
        ]
        | Provider[AsyncOpenAI] = 'openai',
        profile: ModelProfileSpec | None = None,
@@ -278,6 +280,7 @@ def __init__(
            'openrouter',
            'together',
            'vercel',
            'litellm',
        ]
        | Provider[AsyncOpenAI] = 'openai',
        profile: ModelProfileSpec | None = None,
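With `'litellm'` in the `provider` literal of all three overloads, the plain string form should also work, with the provider instantiated via provider inference. A sketch under that assumption (the model name is illustrative):

```python
from pydantic_ai.models.openai import OpenAIChatModel

# Passing the 'litellm' string lets the model construct the provider itself,
# relying on environment-based configuration (an assumption of this sketch)
model = OpenAIChatModel('openai/gpt-4o', provider='litellm')
```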
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
@@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
        from .github import GitHubProvider

        return GitHubProvider
    elif provider == 'litellm':
        from .litellm import LiteLLMProvider

        return LiteLLMProvider
    else:  # pragma: no cover
        raise ValueError(f'Unknown provider: {provider}')
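The new branch simply maps the `'litellm'` string to the provider class:

```python
from pydantic_ai.providers import infer_provider_class
from pydantic_ai.providers.litellm import LiteLLMProvider

# Grounded in the branch above: 'litellm' resolves to LiteLLMProvider
assert infer_provider_class('litellm') is LiteLLMProvider
```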

134 changes: 134 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/litellm.py
@@ -0,0 +1,134 @@
from __future__ import annotations as _annotations

from typing import overload

from httpx import AsyncClient as AsyncHTTPClient

from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.amazon import amazon_model_profile
from pydantic_ai.profiles.anthropic import anthropic_model_profile
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import google_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.groq import groq_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the LiteLLM provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class LiteLLMProvider(Provider[AsyncOpenAI]):
    """Provider for LiteLLM API."""

    @property
    def name(self) -> str:
        return 'litellm'

    @property
    def base_url(self) -> str:
        return str(self.client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        # Map provider prefixes to their profile functions
        provider_to_profile = {
            'anthropic': anthropic_model_profile,
            'openai': openai_model_profile,
            'google': google_model_profile,
            'mistralai': mistral_model_profile,
            'mistral': mistral_model_profile,
            'cohere': cohere_model_profile,
            'amazon': amazon_model_profile,
            'bedrock': amazon_model_profile,
            'meta-llama': meta_model_profile,
            'meta': meta_model_profile,
            'groq': groq_model_profile,
            'deepseek': deepseek_model_profile,
            'moonshotai': moonshotai_model_profile,
            'x-ai': grok_model_profile,
            'qwen': qwen_model_profile,
        }

        profile = None

        # Check if the model name contains a provider prefix (e.g. 'anthropic/claude-3')
        if '/' in model_name:
            provider_prefix, model_suffix = model_name.split('/', 1)
            if provider_prefix in provider_to_profile:
                profile = provider_to_profile[provider_prefix](model_suffix)

        # If no profile was found, default to the OpenAI profile
        if profile is None:
            profile = openai_model_profile(model_name)

        # As LiteLLMProvider is used with OpenAIChatModel, which uses OpenAIJsonSchemaTransformer,
        # we maintain that behavior
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        http_client: AsyncHTTPClient,
    ) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        api_base: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Initialize a LiteLLM provider.

        Args:
            api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables.
            api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models.
            openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored.
            http_client: Custom HTTP client to use.
        """
        if openai_client is not None:
            self._client = openai_client
            return

        # Create an OpenAI-compatible client; LiteLLM exposes an OpenAI-compatible API,
        # so requests sent to `api_base` are routed through LiteLLM
        if http_client is None:
            http_client = cached_async_http_client(provider='litellm')
        self._client = AsyncOpenAI(
            base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client
        )
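To illustrate the prefix routing in `model_profile` above, a small sketch (the model names are illustrative): a recognized prefix selects that provider's profile, anything else falls back to the OpenAI profile, and both are wrapped with `OpenAIJsonSchemaTransformer`:

```python
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
from pydantic_ai.providers.litellm import LiteLLMProvider

provider = LiteLLMProvider(api_key='placeholder')

# 'anthropic/...' routes to anthropic_model_profile with the suffix after '/'
prefixed = provider.model_profile('anthropic/claude-3-5-sonnet')

# No recognized prefix: falls back to openai_model_profile('my-finetune')
fallback = provider.model_profile('my-finetune')

# Both come back wrapped for use with OpenAIChatModel
assert prefixed.json_schema_transformer is OpenAIJsonSchemaTransformer
assert fallback.json_schema_transformer is OpenAIJsonSchemaTransformer
```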
160 changes: 160 additions & 0 deletions tests/providers/test_litellm.py
@@ -0,0 +1,160 @@
import httpx
import pytest
from pytest_mock import MockerFixture

from ..conftest import try_import

with try_import() as imports_successful:
    from openai import AsyncOpenAI

    from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
    from pydantic_ai.providers.litellm import LiteLLMProvider

pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='OpenAI client not installed'),
    pytest.mark.anyio,
]


def test_init_with_api_config():
    provider = LiteLLMProvider(api_key='test-key', api_base='https://custom.litellm.com/v1')
    assert provider.base_url == 'https://custom.litellm.com/v1/'
    assert provider.client.api_key == 'test-key'


def test_init_without_api_key():
    provider = LiteLLMProvider()
    assert provider.name == 'litellm'
    assert provider.client.api_key == 'litellm-placeholder'


def test_init_with_openai_client():
    openai_client = AsyncOpenAI(api_key='custom-key', base_url='https://custom.openai.com/v1')
    provider = LiteLLMProvider(openai_client=openai_client)
    assert provider.client == openai_client
    assert provider.base_url == 'https://custom.openai.com/v1/'


def test_model_profile_returns_openai_compatible_profile(mocker: MockerFixture):
    provider = LiteLLMProvider(api_key='test-key')

    # Create a proper mock profile object that can be updated
    from dataclasses import dataclass

    @dataclass
    class MockProfile:
        max_tokens: int = 4096
        supports_streaming: bool = True

    mock_profile = MockProfile()
    mock_openai_profile = mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=mock_profile)

    profile = provider.model_profile('gpt-3.5-turbo')

    # Verify openai_model_profile was called with the model name
    mock_openai_profile.assert_called_once_with('gpt-3.5-turbo')

    # Verify the returned profile is an OpenAIModelProfile with OpenAIJsonSchemaTransformer
    assert isinstance(profile, OpenAIModelProfile)
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer


def test_model_profile_with_different_models(mocker: MockerFixture):
    provider = LiteLLMProvider(api_key='test-key')

    # Create mocks for all profile functions
    from dataclasses import dataclass

    @dataclass
    class MockProfile:
        max_tokens: int = 4096
        supports_streaming: bool = True

    # Mock all profile functions
    mock_profiles = {
        'openai': mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=MockProfile()),
        'anthropic': mocker.patch('pydantic_ai.providers.litellm.anthropic_model_profile', return_value=MockProfile()),
        'google': mocker.patch('pydantic_ai.providers.litellm.google_model_profile', return_value=MockProfile()),
        'meta': mocker.patch('pydantic_ai.providers.litellm.meta_model_profile', return_value=MockProfile()),
        'mistral': mocker.patch('pydantic_ai.providers.litellm.mistral_model_profile', return_value=MockProfile()),
        'cohere': mocker.patch('pydantic_ai.providers.litellm.cohere_model_profile', return_value=MockProfile()),
        'amazon': mocker.patch('pydantic_ai.providers.litellm.amazon_model_profile', return_value=MockProfile()),
        'deepseek': mocker.patch('pydantic_ai.providers.litellm.deepseek_model_profile', return_value=MockProfile()),
        'groq': mocker.patch('pydantic_ai.providers.litellm.groq_model_profile', return_value=MockProfile()),
        'grok': mocker.patch('pydantic_ai.providers.litellm.grok_model_profile', return_value=MockProfile()),
        'moonshotai': mocker.patch('pydantic_ai.providers.litellm.moonshotai_model_profile', return_value=MockProfile()),
        'qwen': mocker.patch('pydantic_ai.providers.litellm.qwen_model_profile', return_value=MockProfile()),
    }

    # Test models without a provider prefix (should use the openai profile)
    models_without_prefix = ['gpt-4', 'claude-3-sonnet', 'gemini-pro', 'llama2-70b']

    for model in models_without_prefix:
        profile = provider.model_profile(model)
        assert isinstance(profile, OpenAIModelProfile)
        assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Verify openai_model_profile was called for each model without a prefix
    assert mock_profiles['openai'].call_count == len(models_without_prefix)

    # Reset all call counts
    for mock in mock_profiles.values():
        mock.reset_mock()

    # Test all provider prefixes
    test_cases = [
        ('anthropic/claude-3-haiku', 'anthropic', 'claude-3-haiku'),
        ('openai/gpt-4-turbo', 'openai', 'gpt-4-turbo'),
        ('google/gemini-1.5-pro', 'google', 'gemini-1.5-pro'),
        ('mistralai/mistral-large', 'mistral', 'mistral-large'),
        ('mistral/mistral-7b', 'mistral', 'mistral-7b'),
        ('cohere/command-r', 'cohere', 'command-r'),
        ('amazon/titan-text', 'amazon', 'titan-text'),
        ('bedrock/claude-v2', 'amazon', 'claude-v2'),
        ('meta-llama/llama-3-8b', 'meta', 'llama-3-8b'),
        ('meta/llama-2-70b', 'meta', 'llama-2-70b'),
        ('groq/llama3-70b', 'groq', 'llama3-70b'),
        ('deepseek/deepseek-coder', 'deepseek', 'deepseek-coder'),
        ('moonshotai/moonshot-v1', 'moonshotai', 'moonshot-v1'),
        ('x-ai/grok-beta', 'grok', 'grok-beta'),
        ('qwen/qwen-72b', 'qwen', 'qwen-72b'),
    ]

    for model_name, expected_profile, expected_suffix in test_cases:
        profile = provider.model_profile(model_name)
        assert isinstance(profile, OpenAIModelProfile)
        assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
        # Verify the correct profile function was called with the correct suffix
        mock_profiles[expected_profile].assert_called_with(expected_suffix)
        mock_profiles[expected_profile].reset_mock()

    # Test an unknown provider prefix (should fall back to openai)
    provider.model_profile('unknown-provider/some-model')
    mock_profiles['openai'].assert_called_once_with('unknown-provider/some-model')


async def test_cached_http_client_usage(mocker: MockerFixture):
    # Create a real AsyncClient for the mock
    async with httpx.AsyncClient() as mock_cached_client:
        mock_cached_http_client_func = mocker.patch(
            'pydantic_ai.providers.litellm.cached_async_http_client', return_value=mock_cached_client
        )

        provider = LiteLLMProvider(api_key='test-key')

        # Verify cached_async_http_client was called with the 'litellm' provider
        mock_cached_http_client_func.assert_called_once_with(provider='litellm')

        # Verify the client was created
        assert isinstance(provider.client, AsyncOpenAI)


async def test_init_with_http_client_overrides_cached():
    async with httpx.AsyncClient() as custom_client:
        provider = LiteLLMProvider(api_key='test-key', http_client=custom_client)

        # Verify the provider was created successfully with the custom client
        assert isinstance(provider.client, AsyncOpenAI)
        assert provider.client.api_key == 'test-key'