 
 def test_init_with_api_config():
     provider = LiteLLMProvider(api_key='test-key', api_base='https://custom.litellm.com/v1')
-    assert provider.base_url == 'https://custom.litellm.com/v1'
+    assert provider.base_url == 'https://custom.litellm.com/v1/'
     assert provider.client.api_key == 'test-key'
 
 
@@ -30,7 +30,7 @@ def test_init_with_custom_llm_provider():
 def test_init_without_api_key():
     provider = LiteLLMProvider()
     assert provider.name == 'litellm'
-    assert provider.base_url == 'https://api.litellm.ai/v1'
+    assert provider.base_url == 'https://api.litellm.ai/v1/'
     assert provider.client.api_key == 'litellm-placeholder'
 
 
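The two assertion changes above now expect `base_url` to end with a trailing slash, presumably because the underlying OpenAI-compatible client normalizes `api_base` that way. The slash matters for relative URL resolution: under RFC 3986 joining, a relative path replaces the last path segment of the base unless the base ends with `/`. A minimal illustration using `httpx` (sketch only, not the provider's code):

```python
import httpx

# Joining a relative path against a base URL *without* a trailing slash drops
# the final segment ('/v1'); with the slash, the '/v1' prefix is preserved.
without_slash = httpx.URL('https://api.litellm.ai/v1').join('chat/completions')
with_slash = httpx.URL('https://api.litellm.ai/v1/').join('chat/completions')

assert str(without_slash) == 'https://api.litellm.ai/chat/completions'
assert str(with_slash) == 'https://api.litellm.ai/v1/chat/completions'
```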
@@ -68,24 +68,77 @@ class MockProfile:
 def test_model_profile_with_different_models(mocker: MockerFixture):
     provider = LiteLLMProvider(api_key='test-key')
 
-    mock_openai_profile = mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value={})
+    # Create mocks for all profile functions
+    from dataclasses import dataclass
 
-    # Test with different model formats that LiteLLM supports
-    test_models = [
-        'gpt-4',
-        'claude-3-sonnet-20240229',
-        'gemini-pro',
-        'llama2-70b-chat',
-        'anthropic/claude-3-haiku-20240307',
-    ]
+    @dataclass
+    class MockProfile:
+        max_tokens: int = 4096
+        supports_streaming: bool = True
 
-    for model in test_models:
+    # Mock all profile functions
+    mock_profiles = {
+        'openai': mocker.patch('pydantic_ai.providers.litellm.openai_model_profile', return_value=MockProfile()),
+        'anthropic': mocker.patch('pydantic_ai.providers.litellm.anthropic_model_profile', return_value=MockProfile()),
+        'google': mocker.patch('pydantic_ai.providers.litellm.google_model_profile', return_value=MockProfile()),
+        'meta': mocker.patch('pydantic_ai.providers.litellm.meta_model_profile', return_value=MockProfile()),
+        'mistral': mocker.patch('pydantic_ai.providers.litellm.mistral_model_profile', return_value=MockProfile()),
+        'cohere': mocker.patch('pydantic_ai.providers.litellm.cohere_model_profile', return_value=MockProfile()),
+        'amazon': mocker.patch('pydantic_ai.providers.litellm.amazon_model_profile', return_value=MockProfile()),
+        'deepseek': mocker.patch('pydantic_ai.providers.litellm.deepseek_model_profile', return_value=MockProfile()),
+        'groq': mocker.patch('pydantic_ai.providers.litellm.groq_model_profile', return_value=MockProfile()),
+        'grok': mocker.patch('pydantic_ai.providers.litellm.grok_model_profile', return_value=MockProfile()),
+        'moonshotai': mocker.patch(
+            'pydantic_ai.providers.litellm.moonshotai_model_profile', return_value=MockProfile()
+        ),
+        'qwen': mocker.patch('pydantic_ai.providers.litellm.qwen_model_profile', return_value=MockProfile()),
+    }
+
+    # Test models without provider prefix (should use openai profile)
+    models_without_prefix = ['gpt-4', 'claude-3-sonnet', 'gemini-pro', 'llama2-70b']
+
+    for model in models_without_prefix:
         profile = provider.model_profile(model)
         assert isinstance(profile, OpenAIModelProfile)
         assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
 
-    # Verify openai_model_profile was called for each model
-    assert mock_openai_profile.call_count == len(test_models)
+    # Verify openai_model_profile was called for each model without prefix
+    assert mock_profiles['openai'].call_count == len(models_without_prefix)
+
+    # Reset all call counts
+    for mock in mock_profiles.values():
+        mock.reset_mock()
+
+    # Test all provider prefixes
+    test_cases = [
+        ('anthropic/claude-3-haiku', 'anthropic', 'claude-3-haiku'),
+        ('openai/gpt-4-turbo', 'openai', 'gpt-4-turbo'),
+        ('google/gemini-1.5-pro', 'google', 'gemini-1.5-pro'),
+        ('mistralai/mistral-large', 'mistral', 'mistral-large'),
+        ('mistral/mistral-7b', 'mistral', 'mistral-7b'),
+        ('cohere/command-r', 'cohere', 'command-r'),
+        ('amazon/titan-text', 'amazon', 'titan-text'),
+        ('bedrock/claude-v2', 'amazon', 'claude-v2'),
+        ('meta-llama/llama-3-8b', 'meta', 'llama-3-8b'),
+        ('meta/llama-2-70b', 'meta', 'llama-2-70b'),
+        ('groq/llama3-70b', 'groq', 'llama3-70b'),
+        ('deepseek/deepseek-coder', 'deepseek', 'deepseek-coder'),
+        ('moonshotai/moonshot-v1', 'moonshotai', 'moonshot-v1'),
+        ('x-ai/grok-beta', 'grok', 'grok-beta'),
+        ('qwen/qwen-72b', 'qwen', 'qwen-72b'),
+    ]
+
+    for model_name, expected_profile, expected_suffix in test_cases:
+        profile = provider.model_profile(model_name)
+        assert isinstance(profile, OpenAIModelProfile)
+        assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer
+        # Verify the correct profile function was called with the correct suffix
+        mock_profiles[expected_profile].assert_called_with(expected_suffix)
+        mock_profiles[expected_profile].reset_mock()
+
+    # Test unknown provider prefix (should fall back to openai)
+    provider.model_profile('unknown-provider/some-model')
+    mock_profiles['openai'].assert_called_once_with('unknown-provider/some-model')
 
 
 async def test_cached_http_client_usage(mocker: MockerFixture):
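For context, the new assertions encode a prefix-to-profile dispatch. Below is a hypothetical, self-contained sketch of that logic inferred purely from the test cases above; `_profile`, `PREFIX_TO_PROFILE`, and `resolve_profile` are illustrative stand-ins, not the actual `LiteLLMProvider.model_profile` implementation:

```python
from typing import Any, Callable


def _profile(provider: str) -> Callable[[str], dict[str, Any]]:
    # Stand-in for the per-provider *_model_profile functions the test mocks.
    return lambda model_name: {'provider': provider, 'model': model_name}


# prefix -> profile function; mapping inferred from the test_cases table above
PREFIX_TO_PROFILE: dict[str, Callable[[str], dict[str, Any]]] = {
    'anthropic': _profile('anthropic'),
    'openai': _profile('openai'),
    'google': _profile('google'),
    'mistralai': _profile('mistral'),
    'mistral': _profile('mistral'),
    'cohere': _profile('cohere'),
    'amazon': _profile('amazon'),
    'bedrock': _profile('amazon'),
    'meta-llama': _profile('meta'),
    'meta': _profile('meta'),
    'groq': _profile('groq'),
    'deepseek': _profile('deepseek'),
    'moonshotai': _profile('moonshotai'),
    'x-ai': _profile('grok'),
    'qwen': _profile('qwen'),
}


def resolve_profile(model_name: str) -> dict[str, Any]:
    prefix, sep, suffix = model_name.partition('/')
    if sep and prefix in PREFIX_TO_PROFILE:
        # Known provider prefix: dispatch on it with the prefix stripped.
        return PREFIX_TO_PROFILE[prefix](suffix)
    # No prefix, or an unrecognized one: fall back to the OpenAI profile
    # with the *full* model name.
    return _profile('openai')(model_name)


assert resolve_profile('bedrock/claude-v2') == {'provider': 'amazon', 'model': 'claude-v2'}
assert resolve_profile('unknown-provider/some-model') == {
    'provider': 'openai',
    'model': 'unknown-provider/some-model',
}
```

The fallback branch mirrors the final assertion in the test: an unrecognized prefix is passed through whole to the OpenAI profile rather than being stripped.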