diff --git a/docs/api/providers.md b/docs/api/providers.md index 7b922c4d07..cc90eeffbb 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -39,3 +39,5 @@ ::: pydantic_ai.providers.moonshotai.MoonshotAIProvider ::: pydantic_ai.providers.ollama.OllamaProvider + +::: pydantic_ai.providers.litellm.LiteLLMProvider diff --git a/docs/models/openai.md b/docs/models/openai.md index 396e61df71..86febbcc17 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -563,3 +563,31 @@ result = agent.run_sync('What is the capital of France?') print(result.output) #> The capital of France is Paris. ``` + +### LiteLLM + +To use [LiteLLM](https://www.litellm.ai/), set the configs as outlined in the [doc](https://docs.litellm.ai/docs/set_keys). In `LiteLLMProvider`, you can pass `api_base` and `api_key`. The value of these configs will depend on your setup. For example, if you are using OpenAI models, then you need to pass `https://api.openai.com/v1` as the `api_base` and your OpenAI API key as the `api_key`. If you are using a LiteLLM proxy server running on your local machine, then you need to pass `http://localhost:` as the `api_base` and your LiteLLM API key (or a placeholder) as the `api_key`. + +To use custom LLMs, use `custom/` prefix in the model name. + +Once you have the configs, use the [`LiteLLMProvider`][pydantic_ai.providers.litellm.LiteLLMProvider] as follows: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.litellm import LiteLLMProvider + +model = OpenAIChatModel( + 'openai/gpt-3.5-turbo', + provider=LiteLLMProvider( + api_base='', + api_key='' + ) +) +agent = Agent(model) + +result = agent.run_sync('What is the capital of France?') +print(result.output) +#> The capital of France is Paris. +... +``` diff --git a/docs/models/overview.md b/docs/models/overview.md index e68995c39c..4e52e1380d 100644 --- a/docs/models/overview.md +++ b/docs/models/overview.md @@ -27,6 +27,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used - [Heroku](openai.md#heroku-ai) - [GitHub Models](openai.md#github-models) - [Cerebras](openai.md#cerebras) +- [LiteLLM](openai.md#litellm) Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md) for testing and development. diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 86d1e79721..0d6e04f6e8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -728,6 +728,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 'openrouter', 'together', 'vercel', + 'litellm', ): from .openai import OpenAIChatModel diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 48c2bfb4be..24afad6ab0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -225,6 +225,7 @@ def __init__( 'openrouter', 'together', 'vercel', + 'litellm', ] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, @@ -252,6 +253,7 @@ def __init__( 'openrouter', 'together', 'vercel', + 'litellm', ] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, @@ -278,6 +280,7 @@ def __init__( 'openrouter', 'together', 'vercel', + 'litellm', ] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index b61dc44cd4..e266fd544e 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -135,6 +135,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901 from .github import GitHubProvider return GitHubProvider + elif provider == 'litellm': + from .litellm import LiteLLMProvider + + return LiteLLMProvider else: # pragma: no cover raise ValueError(f'Unknown provider: {provider}') diff --git a/pydantic_ai_slim/pydantic_ai/providers/litellm.py b/pydantic_ai_slim/pydantic_ai/providers/litellm.py new file mode 100644 index 0000000000..91db24d4a0 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/litellm.py @@ -0,0 +1,134 @@ +from __future__ import annotations as _annotations + +from typing import overload + +from httpx import AsyncClient as AsyncHTTPClient +from openai import AsyncOpenAI + +from pydantic_ai.models import cached_async_http_client +from pydantic_ai.profiles import ModelProfile +from pydantic_ai.profiles.amazon import amazon_model_profile +from pydantic_ai.profiles.anthropic import anthropic_model_profile +from pydantic_ai.profiles.cohere import cohere_model_profile +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.google import google_model_profile +from pydantic_ai.profiles.grok import grok_model_profile +from pydantic_ai.profiles.groq import groq_model_profile +from pydantic_ai.profiles.meta import meta_model_profile +from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.moonshotai import moonshotai_model_profile +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile +from pydantic_ai.profiles.qwen import qwen_model_profile +from pydantic_ai.providers import Provider + +try: + from openai import AsyncOpenAI +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `openai` package to use the LiteLLM provider, ' + 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + ) from _import_error + + +class LiteLLMProvider(Provider[AsyncOpenAI]): + """Provider for LiteLLM API.""" + + @property + def name(self) -> str: + return 'litellm' + + @property + def base_url(self) -> str: + return str(self.client.base_url) + + @property + def client(self) -> AsyncOpenAI: + return self._client + + def model_profile(self, model_name: str) -> ModelProfile | None: + # Map provider prefixes to their profile functions + provider_to_profile = { + 'anthropic': anthropic_model_profile, + 'openai': openai_model_profile, + 'google': google_model_profile, + 'mistralai': mistral_model_profile, + 'mistral': mistral_model_profile, + 'cohere': cohere_model_profile, + 'amazon': amazon_model_profile, + 'bedrock': amazon_model_profile, + 'meta-llama': meta_model_profile, + 'meta': meta_model_profile, + 'groq': groq_model_profile, + 'deepseek': deepseek_model_profile, + 'moonshotai': moonshotai_model_profile, + 'x-ai': grok_model_profile, + 'qwen': qwen_model_profile, + } + + profile = None + + # Check if model name contains a provider prefix (e.g., "anthropic/claude-3") + if '/' in model_name: + provider_prefix, model_suffix = model_name.split('/', 1) + if provider_prefix in provider_to_profile: + profile = provider_to_profile[provider_prefix](model_suffix) + + # If no profile found, default to OpenAI profile + if profile is None: + profile = openai_model_profile(model_name) + + # As LiteLLMProvider is used with OpenAIModel, which uses OpenAIJsonSchemaTransformer, + # we maintain that behavior + return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) + + @overload + def __init__( + self, + *, + api_key: str | None = None, + api_base: str | None = None, + ) -> None: ... + + @overload + def __init__( + self, + *, + api_key: str | None = None, + api_base: str | None = None, + http_client: AsyncHTTPClient, + ) -> None: ... + + @overload + def __init__(self, *, openai_client: AsyncOpenAI) -> None: ... + + def __init__( + self, + *, + api_key: str | None = None, + api_base: str | None = None, + openai_client: AsyncOpenAI | None = None, + http_client: AsyncHTTPClient | None = None, + ) -> None: + """Initialize a LiteLLM provider. + + Args: + api_key: API key for the model provider. If None, LiteLLM will try to get it from environment variables. + api_base: Base URL for the model provider. Use this for custom endpoints or self-hosted models. + openai_client: Pre-configured OpenAI client. If provided, other parameters are ignored. + http_client: Custom HTTP client to use. + """ + if openai_client is not None: + self._client = openai_client + return + + # Create OpenAI client that will be used with LiteLLM's completion function + # The actual API calls will be intercepted and routed through LiteLLM + if http_client is not None: + self._client = AsyncOpenAI( + base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client + ) + else: + http_client = cached_async_http_client(provider='litellm') + self._client = AsyncOpenAI( + base_url=api_base, api_key=api_key or 'litellm-placeholder', http_client=http_client + ) diff --git a/tests/providers/test_litellm.py b/tests/providers/test_litellm.py new file mode 100644 index 0000000000..4acb9c7e93 --- /dev/null +++ b/tests/providers/test_litellm.py @@ -0,0 +1,160 @@ +import httpx +import pytest +from pytest_mock import MockerFixture + +from ..conftest import try_import + +with try_import() as imports_successful: + from openai import AsyncOpenAI + + from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile + from pydantic_ai.providers.litellm import LiteLLMProvider + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='OpenAI client not installed'), + pytest.mark.anyio, +] + + +def test_init_with_api_config(): + provider = LiteLLMProvider(api_key='test-key', api_base='https://custom.litellm.com/v1') + assert provider.base_url == 'https://custom.litellm.com/v1/' + assert provider.client.api_key == 'test-key' + + +def test_init_without_api_key(): + provider = LiteLLMProvider() + assert provider.name == 'litellm' + assert provider.client.api_key == 'litellm-placeholder' + + +def test_init_with_openai_client(): + openai_client = AsyncOpenAI(api_key='custom-key', base_url='https://custom.openai.com/v1') + provider = LiteLLMProvider(openai_client=openai_client) + assert provider.client == openai_client + assert provider.base_url == 'https://custom.openai.com/v1/' + + +def test_model_profile_returns_openai_compatible_profile(mocker: MockerFixture): + provider = LiteLLMProvider(api_key='test-key') + + # Create a proper mock profile object that can be updated + from dataclasses import dataclass + + @dataclass + class MockProfile: + max_tokens: int = 4096 + supports_streaming: bool = True + + mock_profile = MockProfile() + mock_openai_profile = mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=mock_profile) + + profile = provider.model_profile('gpt-3.5-turbo') + + # Verify openai_model_profile was called with the model name + mock_openai_profile.assert_called_once_with('gpt-3.5-turbo') + + # Verify the returned profile is an OpenAIModelProfile with OpenAIJsonSchemaTransformer + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + +def test_model_profile_with_different_models(mocker: MockerFixture): + provider = LiteLLMProvider(api_key='test-key') + + # Create mocks for all profile functions + from dataclasses import dataclass + + @dataclass + class MockProfile: + max_tokens: int = 4096 + supports_streaming: bool = True + + # Mock all profile functions + mock_profiles = { + 'openai': mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=MockProfile()), + 'anthropic': mocker.patch('pydantic_ai.providers.litellm.anthropic_model_profile', return_value=MockProfile()), + 'google': mocker.patch('pydantic_ai.providers.litellm.google_model_profile', return_value=MockProfile()), + 'meta': mocker.patch('pydantic_ai.providers.litellm.meta_model_profile', return_value=MockProfile()), + 'mistral': mocker.patch('pydantic_ai.providers.litellm.mistral_model_profile', return_value=MockProfile()), + 'cohere': mocker.patch('pydantic_ai.providers.litellm.cohere_model_profile', return_value=MockProfile()), + 'amazon': mocker.patch('pydantic_ai.providers.litellm.amazon_model_profile', return_value=MockProfile()), + 'deepseek': mocker.patch('pydantic_ai.providers.litellm.deepseek_model_profile', return_value=MockProfile()), + 'groq': mocker.patch('pydantic_ai.providers.litellm.groq_model_profile', return_value=MockProfile()), + 'grok': mocker.patch('pydantic_ai.providers.litellm.grok_model_profile', return_value=MockProfile()), + 'moonshotai': mocker.patch( + 'pydantic_ai.providers.litellm.moonshotai_model_profile', return_value=MockProfile() + ), + 'qwen': mocker.patch('pydantic_ai.providers.litellm.qwen_model_profile', return_value=MockProfile()), + } + + # Test models without provider prefix (should use openai profile) + models_without_prefix = ['gpt-4', 'claude-3-sonnet', 'gemini-pro', 'llama2-70b'] + + for model in models_without_prefix: + profile = provider.model_profile(model) + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Verify openai_model_profile was called for each model without prefix + assert mock_profiles['openai'].call_count == len(models_without_prefix) + + # Reset all call counts + for mock in mock_profiles.values(): + mock.reset_mock() + + # Test all provider prefixes + test_cases = [ + ('anthropic/claude-3-haiku', 'anthropic', 'claude-3-haiku'), + ('openai/gpt-4-turbo', 'openai', 'gpt-4-turbo'), + ('google/gemini-1.5-pro', 'google', 'gemini-1.5-pro'), + ('mistralai/mistral-large', 'mistral', 'mistral-large'), + ('mistral/mistral-7b', 'mistral', 'mistral-7b'), + ('cohere/command-r', 'cohere', 'command-r'), + ('amazon/titan-text', 'amazon', 'titan-text'), + ('bedrock/claude-v2', 'amazon', 'claude-v2'), + ('meta-llama/llama-3-8b', 'meta', 'llama-3-8b'), + ('meta/llama-2-70b', 'meta', 'llama-2-70b'), + ('groq/llama3-70b', 'groq', 'llama3-70b'), + ('deepseek/deepseek-coder', 'deepseek', 'deepseek-coder'), + ('moonshotai/moonshot-v1', 'moonshotai', 'moonshot-v1'), + ('x-ai/grok-beta', 'grok', 'grok-beta'), + ('qwen/qwen-72b', 'qwen', 'qwen-72b'), + ] + + for model_name, expected_profile, expected_suffix in test_cases: + profile = provider.model_profile(model_name) + assert isinstance(profile, OpenAIModelProfile) + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + # Verify the correct profile function was called with the correct suffix + mock_profiles[expected_profile].assert_called_with(expected_suffix) + mock_profiles[expected_profile].reset_mock() + + # Test unknown provider prefix (should fall back to openai) + provider.model_profile('unknown-provider/some-model') + mock_profiles['openai'].assert_called_once_with('unknown-provider/some-model') + + +async def test_cached_http_client_usage(mocker: MockerFixture): + # Create a real AsyncClient for the mock + async with httpx.AsyncClient() as mock_cached_client: + mock_cached_http_client_func = mocker.patch( + 'pydantic_ai.providers.litellm.cached_async_http_client', return_value=mock_cached_client + ) + + provider = LiteLLMProvider(api_key='test-key') + + # Verify cached_async_http_client was called with 'litellm' provider + mock_cached_http_client_func.assert_called_once_with(provider='litellm') + + # Verify the client was created + assert isinstance(provider.client, AsyncOpenAI) + + +async def test_init_with_http_client_overrides_cached(): + async with httpx.AsyncClient() as custom_client: + provider = LiteLLMProvider(api_key='test-key', http_client=custom_client) + + # Verify the provider was created successfully with custom client + assert isinstance(provider.client, AsyncOpenAI) + assert provider.client.api_key == 'test-key' diff --git a/tests/providers/test_provider_names.py b/tests/providers/test_provider_names.py index e09ba2a11c..695383147e 100644 --- a/tests/providers/test_provider_names.py +++ b/tests/providers/test_provider_names.py @@ -25,6 +25,7 @@ from pydantic_ai.providers.grok import GrokProvider from pydantic_ai.providers.groq import GroqProvider from pydantic_ai.providers.heroku import HerokuProvider + from pydantic_ai.providers.litellm import LiteLLMProvider from pydantic_ai.providers.mistral import MistralProvider from pydantic_ai.providers.moonshotai import MoonshotAIProvider from pydantic_ai.providers.ollama import OllamaProvider @@ -52,6 +53,7 @@ ('heroku', HerokuProvider, 'HEROKU_INFERENCE_KEY'), ('github', GitHubProvider, 'GITHUB_API_KEY'), ('ollama', OllamaProvider, 'OLLAMA_BASE_URL'), + ('litellm', LiteLLMProvider, None), ] if not imports_successful():