Skip to content

Commit 4e48bb7

Browse files
author
Motta Kin
committed
Address review comments
1 parent a3a9693 commit 4e48bb7

File tree

3 files changed

+116
-40
lines changed

3 files changed

+116
-40
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
699699
'together',
700700
'heroku',
701701
'github',
702+
'litellm',
702703
):
703704
from .openai import OpenAIModel
704705

pydantic_ai_slim/pydantic_ai/providers/litellm.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,21 @@
55
from httpx import AsyncClient as AsyncHTTPClient
66
from openai import AsyncOpenAI
77

8+
from pydantic_ai.models import cached_async_http_client
89
from pydantic_ai.profiles import ModelProfile
10+
from pydantic_ai.profiles.amazon import amazon_model_profile
11+
from pydantic_ai.profiles.anthropic import anthropic_model_profile
12+
from pydantic_ai.profiles.cohere import cohere_model_profile
13+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
14+
from pydantic_ai.profiles.google import google_model_profile
15+
from pydantic_ai.profiles.grok import grok_model_profile
16+
from pydantic_ai.profiles.groq import groq_model_profile
17+
from pydantic_ai.profiles.meta import meta_model_profile
18+
from pydantic_ai.profiles.mistral import mistral_model_profile
19+
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
920
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
21+
from pydantic_ai.profiles.qwen import qwen_model_profile
1022
from pydantic_ai.providers import Provider
11-
from pydantic_ai_slim.pydantic_ai.models import cached_async_http_client
1223

1324
try:
1425
from openai import AsyncOpenAI
@@ -28,18 +39,45 @@ def name(self) -> str:
2839

2940
@property
3041
def base_url(self) -> str:
31-
return self._base_url
42+
return str(self.client.base_url)
3243

3344
@property
3445
def client(self) -> AsyncOpenAI:
3546
return self._client
3647

3748
def model_profile(self, model_name: str) -> ModelProfile | None:
38-
# For LiteLLM, we use a basic OpenAI profile since it's OpenAI-compatible
39-
# Users can override this with their own profile if needed
40-
profile = openai_model_profile(model_name)
41-
42-
# As LiteLLMProvider is used with OpenAIModel, which used to use OpenAIJsonSchemaTransformer,
49+
# Map provider prefixes to their profile functions
50+
provider_to_profile = {
51+
'anthropic': anthropic_model_profile,
52+
'openai': openai_model_profile,
53+
'google': google_model_profile,
54+
'mistralai': mistral_model_profile,
55+
'mistral': mistral_model_profile,
56+
'cohere': cohere_model_profile,
57+
'amazon': amazon_model_profile,
58+
'bedrock': amazon_model_profile,
59+
'meta-llama': meta_model_profile,
60+
'meta': meta_model_profile,
61+
'groq': groq_model_profile,
62+
'deepseek': deepseek_model_profile,
63+
'moonshotai': moonshotai_model_profile,
64+
'x-ai': grok_model_profile,
65+
'qwen': qwen_model_profile,
66+
}
67+
68+
profile = None
69+
70+
# Check if model name contains a provider prefix (e.g., "anthropic/claude-3")
71+
if '/' in model_name:
72+
provider_prefix, model_suffix = model_name.split('/', 1)
73+
if provider_prefix in provider_to_profile:
74+
profile = provider_to_profile[provider_prefix](model_suffix)
75+
76+
# If no profile found, default to OpenAI profile
77+
if profile is None:
78+
profile = openai_model_profile(model_name)
79+
80+
# As LiteLLMProvider is used with OpenAIModel, which uses OpenAIJsonSchemaTransformer,
4381
# we maintain that behavior
4482
return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
4583

@@ -85,36 +123,20 @@ def __init__(
85123
"""
86124
if openai_client is not None:
87125
self._client = openai_client
88-
self._base_url = str(openai_client.base_url)
89126
return
90127

91-
# Set up LiteLLM configuration
92-
if api_key:
93-
# Store API key in LiteLLM's global config if needed
94-
# LiteLLM will handle provider-specific API key names
95-
pass
96-
97-
if custom_llm_provider:
98-
# LiteLLM can auto-detect most providers, but this allows override
99-
pass
100-
101128
# Use api_base if provided, otherwise use a generic base URL
102129
# LiteLLM doesn't actually use this URL - it routes internally
103-
self._base_url = api_base or 'https://api.litellm.ai/v1'
130+
base_url = api_base or 'https://api.litellm.ai/v1'
104131

105132
# Create OpenAI client that will be used with LiteLLM's completion function
106133
# The actual API calls will be intercepted and routed through LiteLLM
107134
if http_client is not None:
108135
self._client = AsyncOpenAI(
109-
base_url=self._base_url, api_key=api_key or 'litellm-placeholder', http_client=http_client
136+
base_url=base_url, api_key=api_key or 'litellm-placeholder', http_client=http_client
110137
)
111138
else:
112139
http_client = cached_async_http_client(provider='litellm')
113140
self._client = AsyncOpenAI(
114-
base_url=self._base_url, api_key=api_key or 'litellm-placeholder', http_client=http_client
141+
base_url=base_url, api_key=api_key or 'litellm-placeholder', http_client=http_client
115142
)
116-
117-
# Store configuration for LiteLLM
118-
self._api_key = api_key
119-
self._api_base = api_base
120-
self._custom_llm_provider = custom_llm_provider

tests/providers/test_litellm.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
def test_init_with_api_config():
2020
provider = LiteLLMProvider(api_key='test-key', api_base='https://custom.litellm.com/v1')
21-
assert provider.base_url == 'https://custom.litellm.com/v1'
21+
assert provider.base_url == 'https://custom.litellm.com/v1/'
2222
assert provider.client.api_key == 'test-key'
2323

2424

@@ -30,7 +30,7 @@ def test_init_with_custom_llm_provider():
3030
def test_init_without_api_key():
3131
provider = LiteLLMProvider()
3232
assert provider.name == 'litellm'
33-
assert provider.base_url == 'https://api.litellm.ai/v1'
33+
assert provider.base_url == 'https://api.litellm.ai/v1/'
3434
assert provider.client.api_key == 'litellm-placeholder'
3535

3636

@@ -68,24 +68,77 @@ class MockProfile:
6868
def test_model_profile_with_different_models(mocker: MockerFixture):
6969
provider = LiteLLMProvider(api_key='test-key')
7070

71-
mock_openai_profile = mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value={})
71+
# Create mocks for all profile functions
72+
from dataclasses import dataclass
7273

73-
# Test with different model formats that LiteLLM supports
74-
test_models = [
75-
'gpt-4',
76-
'claude-3-sonnet-20240229',
77-
'gemini-pro',
78-
'llama2-70b-chat',
79-
'anthropic/claude-3-haiku-20240307',
80-
]
74+
@dataclass
75+
class MockProfile:
76+
max_tokens: int = 4096
77+
supports_streaming: bool = True
8178

82-
for model in test_models:
79+
# Mock all profile functions
80+
mock_profiles = {
81+
'openai': mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=MockProfile()),
82+
'anthropic': mocker.patch('pydantic_ai.providers.litellm.anthropic_model_profile', return_value=MockProfile()),
83+
'google': mocker.patch('pydantic_ai.providers.litellm.google_model_profile', return_value=MockProfile()),
84+
'meta': mocker.patch('pydantic_ai.providers.litellm.meta_model_profile', return_value=MockProfile()),
85+
'mistral': mocker.patch('pydantic_ai.providers.litellm.mistral_model_profile', return_value=MockProfile()),
86+
'cohere': mocker.patch('pydantic_ai.providers.litellm.cohere_model_profile', return_value=MockProfile()),
87+
'amazon': mocker.patch('pydantic_ai.providers.litellm.amazon_model_profile', return_value=MockProfile()),
88+
'deepseek': mocker.patch('pydantic_ai.providers.litellm.deepseek_model_profile', return_value=MockProfile()),
89+
'groq': mocker.patch('pydantic_ai.providers.litellm.groq_model_profile', return_value=MockProfile()),
90+
'grok': mocker.patch('pydantic_ai.providers.litellm.grok_model_profile', return_value=MockProfile()),
91+
'moonshotai': mocker.patch(
92+
'pydantic_ai.providers.litellm.moonshotai_model_profile', return_value=MockProfile()
93+
),
94+
'qwen': mocker.patch('pydantic_ai.providers.litellm.qwen_model_profile', return_value=MockProfile()),
95+
}
96+
97+
# Test models without provider prefix (should use openai profile)
98+
models_without_prefix = ['gpt-4', 'claude-3-sonnet', 'gemini-pro', 'llama2-70b']
99+
100+
for model in models_without_prefix:
83101
profile = provider.model_profile(model)
84102
assert isinstance(profile, OpenAIModelProfile)
85103
assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
86104

87-
# Verify openai_model_profile was called for each model
88-
assert mock_openai_profile.call_count == len(test_models)
105+
# Verify openai_model_profile was called for each model without prefix
106+
assert mock_profiles['openai'].call_count == len(models_without_prefix)
107+
108+
# Reset all call counts
109+
for mock in mock_profiles.values():
110+
mock.reset_mock()
111+
112+
# Test all provider prefixes
113+
test_cases = [
114+
('anthropic/claude-3-haiku', 'anthropic', 'claude-3-haiku'),
115+
('openai/gpt-4-turbo', 'openai', 'gpt-4-turbo'),
116+
('google/gemini-1.5-pro', 'google', 'gemini-1.5-pro'),
117+
('mistralai/mistral-large', 'mistral', 'mistral-large'),
118+
('mistral/mistral-7b', 'mistral', 'mistral-7b'),
119+
('cohere/command-r', 'cohere', 'command-r'),
120+
('amazon/titan-text', 'amazon', 'titan-text'),
121+
('bedrock/claude-v2', 'amazon', 'claude-v2'),
122+
('meta-llama/llama-3-8b', 'meta', 'llama-3-8b'),
123+
('meta/llama-2-70b', 'meta', 'llama-2-70b'),
124+
('groq/llama3-70b', 'groq', 'llama3-70b'),
125+
('deepseek/deepseek-coder', 'deepseek', 'deepseek-coder'),
126+
('moonshotai/moonshot-v1', 'moonshotai', 'moonshot-v1'),
127+
('x-ai/grok-beta', 'grok', 'grok-beta'),
128+
('qwen/qwen-72b', 'qwen', 'qwen-72b'),
129+
]
130+
131+
for model_name, expected_profile, expected_suffix in test_cases:
132+
profile = provider.model_profile(model_name)
133+
assert isinstance(profile, OpenAIModelProfile)
134+
assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
135+
# Verify the correct profile function was called with the correct suffix
136+
mock_profiles[expected_profile].assert_called_with(expected_suffix)
137+
mock_profiles[expected_profile].reset_mock()
138+
139+
# Test unknown provider prefix (should fall back to openai)
140+
provider.model_profile('unknown-provider/some-model')
141+
mock_profiles['openai'].assert_called_once_with('unknown-provider/some-model')
89142

90143

91144
async def test_cached_http_client_usage(mocker: MockerFixture):

0 commit comments

Comments (0)