diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index d329154ff6..06258bc6ba 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -764,7 +764,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 model_kind = provider_name if model_kind.startswith('gateway/'): - model_kind = provider_name.removeprefix('gateway/') + from ..providers.gateway import infer_gateway_model + + return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name) if model_kind in ( 'openai', 'azure', diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 7e0bc6a009..256e300099 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -226,7 +226,7 @@ def __init__( self._model_name = model_name if isinstance(provider, str): - provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider) + provider = infer_provider('gateway/converse' if provider == 'gateway' else provider) self._provider = provider self.client = cast('BedrockRuntimeClient', provider.client) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index d17f766219..875d32ee75 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -204,7 +204,7 @@ def __init__( self._model_name = model_name if isinstance(provider, str): - provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider) + provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider) self._provider = provider self.client = provider.client diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index ed1e711823..a51ecff1b3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -375,7 +375,7 @@ def __init__( self._model_name = model_name if isinstance(provider, str): - provider = infer_provider('gateway/openai' if provider == 'gateway' else provider) + provider = infer_provider('gateway/chat' if provider == 'gateway' else provider) self._provider = provider self.client = provider.client @@ -944,7 +944,7 @@ def __init__( self._model_name = model_name if isinstance(provider, str): - provider = infer_provider('gateway/openai' if provider == 'gateway' else provider) + provider = infer_provider('gateway/responses' if provider == 'gateway' else provider) self._provider = provider self.client = provider.client diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 714d29b285..945f4fec4a 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]: if provider.startswith('gateway/'): from .gateway import gateway_provider - provider = provider.removeprefix('gateway/') - return gateway_provider(provider) + api_type = provider.removeprefix('gateway/') + return gateway_provider(api_type) elif provider in ('google-vertex', 'google-gla'): from .google import GoogleProvider diff --git a/pydantic_ai_slim/pydantic_ai/providers/gateway.py b/pydantic_ai_slim/pydantic_ai/providers/gateway.py index bf4e2560bb..7f28aeb606 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/gateway.py +++ b/pydantic_ai_slim/pydantic_ai/providers/gateway.py @@ -17,6 +17,7 @@ from groq import AsyncGroq from openai import AsyncOpenAI + from pydantic_ai.models import Model from pydantic_ai.models.anthropic import AsyncAnthropicClient from pydantic_ai.providers import Provider @@ -25,7 +26,8 @@ @overload def gateway_provider( - upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'], + api_type: Literal['chat', 'responses'], + /, *, api_key: str | None = None, base_url: str | None = None, @@ -35,7 +37,8 @@ def gateway_provider( @overload def gateway_provider( - upstream_provider: Literal['groq'], + api_type: Literal['groq'], + /, *, api_key: str | None = None, base_url: str | None = None, @@ -45,56 +48,63 @@ def gateway_provider( @overload def gateway_provider( - upstream_provider: Literal['google-vertex'], + api_type: Literal['anthropic'], + /, *, api_key: str | None = None, base_url: str | None = None, -) -> Provider[GoogleClient]: ... + http_client: httpx.AsyncClient | None = None, +) -> Provider[AsyncAnthropicClient]: ... @overload def gateway_provider( - upstream_provider: Literal['anthropic'], + api_type: Literal['converse'], + /, *, api_key: str | None = None, base_url: str | None = None, -) -> Provider[AsyncAnthropicClient]: ... +) -> Provider[BaseClient]: ... @overload def gateway_provider( - upstream_provider: Literal['bedrock'], + api_type: Literal['gemini'], + /, *, api_key: str | None = None, base_url: str | None = None, -) -> Provider[BaseClient]: ... + http_client: httpx.AsyncClient | None = None, +) -> Provider[GoogleClient]: ... @overload def gateway_provider( - upstream_provider: str, + api_type: str, + /, *, api_key: str | None = None, base_url: str | None = None, ) -> Provider[Any]: ... -UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock'] +APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq'] def gateway_provider( - upstream_provider: UpstreamProvider | str, + api_type: APIType | str, + /, *, # Every provider api_key: str | None = None, base_url: str | None = None, - # OpenAI, Groq & Anthropic + # OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client. http_client: httpx.AsyncClient | None = None, ) -> Provider[Any]: """Create a new Gateway provider. Args: - upstream_provider: The upstream provider to use. + api_type: Determines the API type to use. api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable will be used if available. base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL` @@ -109,18 +119,18 @@ def gateway_provider( ) base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL) - http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}') + http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}') http_client.event_hooks = {'request': [_request_hook(api_key)]} - if upstream_provider in ('openai', 'openai-chat', 'openai-responses'): + if api_type in ('chat', 'responses'): from .openai import OpenAIProvider - return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client) - elif upstream_provider == 'groq': + return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client) + elif api_type == 'groq': from .groq import GroqProvider return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client) - elif upstream_provider == 'anthropic': + elif api_type == 'anthropic': from anthropic import AsyncAnthropic from .anthropic import AnthropicProvider @@ -132,25 +142,25 @@ def gateway_provider( http_client=http_client, ) ) - elif upstream_provider == 'bedrock': + elif api_type == 'converse': from .bedrock import BedrockProvider return BedrockProvider( api_key=api_key, - base_url=_merge_url_path(base_url, 'bedrock'), + base_url=_merge_url_path(base_url, api_type), region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError ) - elif upstream_provider == 'google-vertex': + elif api_type == 'gemini': from .google import GoogleProvider return GoogleProvider( vertexai=True, api_key=api_key, - base_url=_merge_url_path(base_url, 'google-vertex'), + base_url=_merge_url_path(base_url, 'gemini'), http_client=http_client, ) else: - raise UserError(f'Unknown upstream provider: {upstream_provider}') + raise UserError(f'Unknown API type: {api_type}') def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]: @@ -182,3 +192,33 @@ def _merge_url_path(base_url: str, path: str) -> str: path: The path to merge. """ return base_url.rstrip('/') + '/' + path.lstrip('/') + + +def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model: + """Infer the model class for a given API type.""" + if api_type == 'chat': + from pydantic_ai.models.openai import OpenAIChatModel + + return OpenAIChatModel(model_name=model_name, provider='gateway') + elif api_type == 'groq': + from pydantic_ai.models.groq import GroqModel + + return GroqModel(model_name=model_name, provider='gateway') + elif api_type == 'responses': + from pydantic_ai.models.openai import OpenAIResponsesModel + + return OpenAIResponsesModel(model_name=model_name, provider='gateway') + elif api_type == 'gemini': + from pydantic_ai.models.google import GoogleModel + + return GoogleModel(model_name=model_name, provider='gateway') + elif api_type == 'converse': + from pydantic_ai.models.bedrock import BedrockConverseModel + + return BedrockConverseModel(model_name=model_name, provider='gateway') + elif api_type == 'anthropic': + from pydantic_ai.models.anthropic import AnthropicModel + + return AnthropicModel(model_name=model_name, provider='gateway') + else: + raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index a2e71d3cec..1c55c62948 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -134,31 +134,31 @@ interactions: parsed_body: data: - created: 0 - id: qwen-3-32b + id: qwen-3-235b-a22b-thinking-2507 object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-instruct-2507 + id: llama-3.3-70b object: model owned_by: Cerebras - created: 0 - id: gpt-oss-120b + id: qwen-3-235b-a22b-instruct-2507 object: model owned_by: Cerebras - created: 0 - id: zai-glm-4.6 + id: qwen-3-32b object: model owned_by: Cerebras - created: 0 - id: llama3.1-8b + id: zai-glm-4.6 object: model owned_by: Cerebras - created: 0 - id: llama-3.3-70b + id: gpt-oss-120b object: model owned_by: Cerebras - created: 0 - id: qwen-3-235b-a22b-thinking-2507 + id: llama3.1-8b object: model owned_by: Cerebras object: list diff --git a/tests/models/test_model.py b/tests/models/test_model.py index b824fcbfe6..20a50699d8 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -29,30 +29,21 @@ TEST_CASES = [ pytest.param( {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, - 'gateway/openai:gpt-5', + 'gateway/chat:gpt-5', 'gpt-5', 'openai', 'openai', OpenAIChatModel, - id='gateway/openai:gpt-5', + id='gateway/chat:gpt-5', ), pytest.param( {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, - 'gateway/openai-chat:gpt-5', - 'gpt-5', - 'openai', - 'openai', - OpenAIChatModel, - id='gateway/openai-chat:gpt-5', - ), - pytest.param( - {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, - 'gateway/openai-responses:gpt-5', + 'gateway/responses:gpt-5', 'gpt-5', 'openai', 'openai', OpenAIResponsesModel, - id='gateway/openai-responses:gpt-5', + id='gateway/responses:gpt-5', ), pytest.param( {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, @@ -65,12 +56,12 @@ ), pytest.param( {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, - 'gateway/google-vertex:gemini-1.5-flash', + 'gateway/gemini:gemini-1.5-flash', 'gemini-1.5-flash', 'google-vertex', 'google', GoogleModel, - id='gateway/google-vertex:gemini-1.5-flash', + id='gateway/gemini:gemini-1.5-flash', ), pytest.param( {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, @@ -83,12 +74,12 @@ ), pytest.param( {'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'}, - 'gateway/bedrock:amazon.nova-micro-v1:0', + 'gateway/converse:amazon.nova-micro-v1:0', 'amazon.nova-micro-v1:0', 'bedrock', 'bedrock', BedrockConverseModel, - id='gateway/bedrock:amazon.nova-micro-v1:0', + id='gateway/converse:amazon.nova-micro-v1:0', ), pytest.param( {'OPENAI_API_KEY': 'openai-api-key'}, diff --git a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml index c6040e7ad7..ea9436f55d 100644 --- a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml +++ b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml @@ -5,7 +5,7 @@ interactions: headers: amz-sdk-invocation-id: - !!binary | - MWYwNDlkMTQtMjVmMC00YTRhLWJhYmMtNTQ0MDdhMmRlNjgw + MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj amz-sdk-request: - !!binary | YXR0ZW1wdD0x @@ -15,7 +15,7 @@ interactions: - !!binary | YXBwbGljYXRpb24vanNvbg== method: POST - uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse + uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse response: headers: content-length: @@ -26,7 +26,7 @@ interactions: - 0.0000USD parsed_body: metrics: - latencyMs: 668 + latencyMs: 682 output: message: content: diff --git a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_google_vertex.yaml b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_google_vertex.yaml index f1a6a76fd4..c92f55fe07 100644 --- a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_google_vertex.yaml +++ b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_google_vertex.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '111' + - '141' content-type: - application/json host: @@ -19,46 +19,48 @@ interactions: - parts: - text: What is the capital of France? role: user - generationConfig: {} - uri: http://localhost:8787/google-vertex/v1beta1/publishers/google/models/gemini-1.5-flash:generateContent + generationConfig: + responseModalities: + - TEXT + uri: http://localhost:8787/gemini/v1beta1/publishers/google/models/gemini-2.5-flash:generateContent response: headers: alt-svc: - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 content-length: - - '536' + - '581' content-type: - application/json; charset=UTF-8 pydantic-ai-gateway-price-estimate: - - 0.0000USD + - 0.0002USD transfer-encoding: - chunked vary: - Origin, X-Origin, Referer parsed_body: candidates: - - avgLogprobs: -0.014910747297108173 + - avgLogprobs: -1.1077090501785278 content: parts: - - text: | - Paris + - text: The capital of France is **Paris**. role: model finishReason: STOP - createTime: '2025-09-11T07:46:07.967270Z' - modelVersion: gemini-1.5-flash-002 - responseId: P37CaOaEO4OugLUPlpariAU + createTime: '2025-11-09T10:55:33.769850Z' + modelVersion: gemini-2.5-flash + responseId: JXMQabr-LsPv2fMP0YexuA8 usageMetadata: - candidatesTokenCount: 2 + candidatesTokenCount: 8 candidatesTokensDetails: - modality: TEXT - tokenCount: 2 + tokenCount: 8 promptTokenCount: 7 promptTokensDetails: - modality: TEXT tokenCount: 7 pydantic_ai_gateway: - cost_estimate: 1.125e-06 - totalTokenCount: 9 + cost_estimate: 0.0001546 + thoughtsTokenCount: 53 + totalTokenCount: 68 trafficType: ON_DEMAND status: code: 200 diff --git a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai.yaml b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai.yaml index a68e5fbc18..904213b062 100644 --- a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai.yaml +++ b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai.yaml @@ -20,15 +20,25 @@ interactions: role: user model: gpt-5 stream: false - uri: http://localhost:8787/openai/chat/completions + uri: http://localhost:8787/chat/chat/completions response: headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 content-length: - '614' content-type: - application/json + openai-processing-ms: + - '2046' + openai-version: + - '2020-10-01' pydantic-ai-gateway-price-estimate: - 0.0001USD + strict-transport-security: + - max-age=31536000; includeSubDomains; preload transfer-encoding: - chunked parsed_body: @@ -40,8 +50,8 @@ interactions: content: Paris. refusal: null role: assistant - created: 1757577825 - id: chatcmpl-CEWZNtQzFkg7hbdjFKPJtCT7CwGyc + created: 1762685398 + id: chatcmpl-CZxHa0hkNXToLinJr0VlRYia1OMT0 model: gpt-5-2025-08-07 object: chat.completion service_tier: default diff --git a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai_responses.yaml b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai_responses.yaml index c1e7a72b9c..1f5b5c8aaf 100644 --- a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai_responses.yaml +++ b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_openai_responses.yaml @@ -22,22 +22,32 @@ interactions: role: user model: gpt-5 stream: false - uri: http://localhost:8787/openai/responses + uri: http://localhost:8787/responses/responses response: headers: + alt-svc: + - h3=":443"; ma=86400 content-length: - - '2081' + - '2651' content-type: - application/json + openai-processing-ms: + - '2444' + openai-version: + - '2020-10-01' pydantic-ai-gateway-price-estimate: - - 0.0001USD + - 0.0007USD + strict-transport-security: + - max-age=31536000; includeSubDomains; preload transfer-encoding: - chunked parsed_body: background: false - created_at: 1757578027 + billing: + payer: developer + created_at: 1762685402 error: null - id: resp_68c2832b6e8881a39b7a5c51ca65f5ed0edb299050b14e3e + id: resp_0b0a499f6332473100691071da2f6081a184bad9183c461b04 incomplete_details: null instructions: null max_output_tokens: null @@ -46,8 +56,8 @@ interactions: model: gpt-5-2025-08-07 object: response output: - - encrypted_content: gAAAAABowoMtzpLx3I0ACILqVQrD5JLqPsSsnOId4zO9UPFzwHvShOS-m-z-qvUDPOKg8lpiGdxwbnTucs2LKboLMFhy4HdnBKA2Ht5eO0VaFiicrjJX6di0eEciaMi7WfJcmtEqz_E911yZIX112SeooYEbL3wSZ0nchpYkaGL_uVaqLITnIDnutiwenhiqGbLSTM7slXONnUF3JquEIdYhuEdII6usslBXFHKpcv0JGVMAv2m0G24Zj2RA3BihqZnmdCFXUVeEvBIdQGPjf7BK1AGQCg9Bk3FPFJqPjTrjqCBuoP8oloJT-GW0qaOChasym2zWU3Xvaqr7K4uiTR6ss1F5Krex7IF9GErHMBxuKCKKbj3-xb5plTxey-J-dex_swVvzT13RWqRB99TPd3k3wxnZgbWRkieEnw7T37Iq1SpEncke67vwCJNJ1bEUTEcXYiY34x3eAw5nF3wvVl_TwZUtZ8CfLbEdyH4t_WndTNrMUMjxA6TRgD944KKXOiVt-NwwL3dpY5ZCxlHx-NKVyEu1rLWBIFFOMJE50lRzhs7oOR1FUKdTfDsy23zB-p9V1OwGsE5i37GXd-efBVHt10saATXHXikEPHkAtkK5eketZ8NnwRKHPKcbHRRnMBOTD3ygvA8-gYo81oWusWSu_1M4WvHKNpawX7o24eLpXyN7YRnYS1LQonoXPES5n9G5AAzryc7ROwcHwQ_709Ua-2YXqGFgo8X9n5kY3QC-g0ZFxCO0clCBIOImoVEM5VRt0WRK30RFHiKl6_JhZ4OCiHWHb2oN131FOmb_UbMk698CsKNbY6-soEnyWMfFcRGQGsVI1-_vrwK2DeQuzL5jW0Fm3qpdnRKVeVhHyBDaekwHeWClTw= - id: rs_68c2832bf1a481a3b653d20f5608f9750edb299050b14e3e + - encrypted_content: gAAAAABpEHHcIjiqxbZXNndi9xNsTyvngDbxl8seH4HjGuU0P_aiyztPIEtTBwhaq7QRlf2pobZHONLRzflx_6xxB7Do7X9UkcJ_c0zjvwyMGK83CJEwb6_n5qQJy3VdtfQXJRfLJhUSn7ozk4iTn65h_qurYMntYCxP6Cq11I9bHKEaUesKGeYgCwUtF9mmvgTwfRBxefHohNCa8HpLbHUBETMfBecL4OBVJvk7RvpcRtELL8hZgQCvjJGQXYC23TjeN83ctuvIfA5l5n-y6fOtxIc-TcTgLPqUoa6CwYuQRFF_jXXrY8x4bLyPv6X6KI-we4S_lr8h6V-cG_k5f3ezTWq7V_AhCwodZFW0keNdsxVm9X4PDpDNo5_7vobA6XBC6HLzWWpwQDqNWaDmQ6hbA82WewrFKoUSur86ekXDA8Ju82U1HkSzakOs4xUE9Ix9djVNQowDD53BybUCF92SImLZko9j4dJM5TP3zCQ9SYJ_06B_rN3JG5IJ9MM0LNGHFfD6mAn4omBakgG_FCQfBD0t34cEn9VobNqcJffa3slLOikIy9jqy7jUdf2r3cdh4mJpzP5asvcrvm7sDCtnT3aUVOBSWkbLLHAcrT7N60aZjJZahKOi4xlLHAbjKnR4fFbs3qSUIL9Ig44568oNHrGCnIEZ2h7ojpIb1lgNUcEXdIGmhNzLyqkqoWk0Uk8Y_-tD38zPdy-g_xdJjgmHe0797o0pD12aJkMJ3Pkmo9hDe-kZtR4A1rIQFDIJF--NiE2SjT5rZM9ntGrEcEVD6Nc3dxzWnrziTPfWuomovRWDuEbBnSYJ2wYIVl-snX5fPn_ENPS8YBY5ymZhbqOp9r6GA-OrCU5Ipv46VIY521nPcx4pC7PJ6X877N9CCd1j-URlKAyySgZNx6PZ_NiDxLAnsAPx-YlhvuT3Y8qjms7r98RCwa5hmbisrGFBQoKK22Lsinu4Ak_qO673JMJPWNK_ZtFyD7BKfmqPZm_EtmAC8GMXyrvmRCGU7jlienzW8DoA_8z-hidaACvAoVqPpQURxocKUMmHLeAfALDcpJHmWB2vJjl1oKQZ1JerwLatQMIWwONe4eTSKKw5KX4nn9wiGkQQU4bJHnFuanoSiaQvuPAXucp7I7pBBDhJT5X5ixWsR5pcTjs32ILI7YDTwc0tMWZ9gR5910O_K6cZHMPqDN1cCJvuf_5qsGFkVe217Osl7QyXHODCVbK6YXO_D7RGN5FCCypLI4FinDxfewQbiIFAcizF1TlCmJ6gF5vRQyio271qOLKOT497_1TKTbLHcM37D2C7MXrO59sNwl9-UBq1eowuquFiy24vmQePi9A0lNVrSD3kAhMpd8j3sTwIaoWlGabm194Oqm-3kHfyrqR6ExU= + id: rs_0b0a499f6332473100691071db0c2c81a1943920f74bbe507b summary: [] type: reasoning - content: @@ -55,13 +65,14 @@ interactions: logprobs: [] text: Paris. type: output_text - id: msg_68c2832c1dcc81a38f7b5fb9e3cae9000edb299050b14e3e + id: msg_0b0a499f6332473100691071dc471c81a1a55121b76f7d995e role: assistant status: completed type: message parallel_tool_calls: true previous_response_id: null prompt_cache_key: null + prompt_cache_retention: null reasoning: effort: medium summary: null @@ -83,12 +94,12 @@ interactions: input_tokens: 13 input_tokens_details: cached_tokens: 0 - output_tokens: 8 + output_tokens: 72 output_tokens_details: - reasoning_tokens: 0 + reasoning_tokens: 64 pydantic_ai_gateway: - cost_estimate: 9.625000000000001e-05 - total_tokens: 21 + cost_estimate: 0.00073625 + total_tokens: 85 user: null status: code: 200 diff --git a/tests/providers/test_gateway.py b/tests/providers/test_gateway.py index 6194a52636..6b3b16001d 100644 --- a/tests/providers/test_gateway.py +++ b/tests/providers/test_gateway.py @@ -34,15 +34,12 @@ @pytest.mark.parametrize( - 'provider_name, provider_cls', - [('openai', OpenAIProvider), ('openai-chat', OpenAIProvider), ('openai-responses', OpenAIProvider)], + 'provider_name, provider_cls, path', [('chat', OpenAIProvider, 'chat'), ('responses', OpenAIProvider, 'responses')] ) -def test_init_with_base_url( - provider_name: Literal['openai', 'openai-chat', 'openai-responses'], provider_cls: type[Provider[Any]] -): +def test_init_with_base_url(provider_name: Literal['chat', 'responses'], provider_cls: type[Provider[Any]], path: str): provider = gateway_provider(provider_name, base_url='https://example.com/', api_key='foobar') assert isinstance(provider, provider_cls) - assert provider.base_url == 'https://example.com/openai/' + assert provider.base_url == f'https://example.com/{path}/' assert provider.client.api_key == 'foobar' @@ -54,12 +51,12 @@ def test_init_gateway_without_api_key_raises_error(env: TestEnv): 'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(..., api_key=...)` to use the Pydantic AI Gateway provider.' ), ): - gateway_provider('openai') + gateway_provider('chat') async def test_init_with_http_client(): async with httpx.AsyncClient() as http_client: - provider = gateway_provider('openai', http_client=http_client, api_key='foobar') + provider = gateway_provider('chat', http_client=http_client, api_key='foobar') assert provider.client._client == http_client # type: ignore @@ -84,13 +81,12 @@ def vcr_config(): @pytest.mark.parametrize( 'provider_name, provider_cls, path', [ - ('openai', OpenAIProvider, 'openai'), - ('openai-chat', OpenAIProvider, 'openai'), - ('openai-responses', OpenAIProvider, 'openai'), + ('chat', OpenAIProvider, 'chat'), + ('responses', OpenAIProvider, 'responses'), ('groq', GroqProvider, 'groq'), - ('google-vertex', GoogleProvider, 'google-vertex'), + ('gemini', GoogleProvider, 'gemini'), ('anthropic', AnthropicProvider, 'anthropic'), - ('bedrock', BedrockProvider, 'bedrock'), + ('converse', BedrockProvider, 'converse'), ], ) def test_gateway_provider(provider_name: str, provider_cls: type[Provider[Any]], path: str): @@ -98,20 +94,17 @@ def test_gateway_provider(provider_name: str, provider_cls: type[Provider[Any]], assert isinstance(provider, provider_cls) # Some providers add a trailing slash, others don't - assert provider.base_url in ( - f'{GATEWAY_BASE_URL}/{path}/', - f'{GATEWAY_BASE_URL}/{path}', - ) + assert provider.base_url in (f'{GATEWAY_BASE_URL}/{path}/', f'{GATEWAY_BASE_URL}/{path}') @patch.dict(os.environ, {'PYDANTIC_AI_GATEWAY_API_KEY': 'test-api-key'}) def test_gateway_provider_unknown(): - with raises(snapshot('UserError: Unknown upstream provider: foo')): + with raises(snapshot('UserError: Unknown API type: foo')): gateway_provider('foo') async def test_gateway_provider_with_openai(allow_model_requests: None, gateway_api_key: str): - provider = gateway_provider('openai', api_key=gateway_api_key, base_url='http://localhost:8787') + provider = gateway_provider('chat', api_key=gateway_api_key, base_url='http://localhost:8787') model = OpenAIChatModel('gpt-5', provider=provider) agent = Agent(model) @@ -120,7 +113,7 @@ async def test_gateway_provider_with_openai(allow_model_requests: None, gateway_ async def test_gateway_provider_with_openai_responses(allow_model_requests: None, gateway_api_key: str): - provider = gateway_provider('openai-responses', api_key=gateway_api_key, base_url='http://localhost:8787') + provider = gateway_provider('responses', api_key=gateway_api_key, base_url='http://localhost:8787') model = OpenAIResponsesModel('gpt-5', provider=provider) agent = Agent(model) @@ -138,12 +131,12 @@ async def test_gateway_provider_with_groq(allow_model_requests: None, gateway_ap async def test_gateway_provider_with_google_vertex(allow_model_requests: None, gateway_api_key: str): - provider = gateway_provider('google-vertex', api_key=gateway_api_key, base_url='http://localhost:8787') - model = GoogleModel('gemini-1.5-flash', provider=provider) + provider = gateway_provider('gemini', api_key=gateway_api_key, base_url='http://localhost:8787') + model = GoogleModel('gemini-2.5-flash', provider=provider) agent = Agent(model) result = await agent.run('What is the capital of France?') - assert result.output == snapshot('Paris\n') + assert result.output == snapshot('The capital of France is **Paris**.') async def test_gateway_provider_with_anthropic(allow_model_requests: None, gateway_api_key: str): @@ -156,7 +149,7 @@ async def test_gateway_provider_with_anthropic(allow_model_requests: None, gatew async def test_gateway_provider_with_bedrock(allow_model_requests: None, gateway_api_key: str): - provider = gateway_provider('bedrock', api_key=gateway_api_key, base_url='http://localhost:8787') + provider = gateway_provider('converse', api_key=gateway_api_key, base_url='http://localhost:8787') model = BedrockConverseModel('amazon.nova-micro-v1:0', provider=provider) agent = Agent(model) diff --git a/tests/providers/test_provider_names.py b/tests/providers/test_provider_names.py index 97b7f963de..55c859faef 100644 --- a/tests/providers/test_provider_names.py +++ b/tests/providers/test_provider_names.py @@ -60,11 +60,11 @@ ('litellm', LiteLLMProvider, None), ('nebius', NebiusProvider, 'NEBIUS_API_KEY'), ('ovhcloud', OVHcloudProvider, 'OVHCLOUD_API_KEY'), - ('gateway/openai', OpenAIProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), + ('gateway/chat', OpenAIProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), ('gateway/groq', GroqProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), - ('gateway/google-vertex', GoogleProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), + ('gateway/gemini', GoogleProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), ('gateway/anthropic', AnthropicProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), - ('gateway/bedrock', BedrockProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), + ('gateway/converse', BedrockProvider, 'PYDANTIC_AI_GATEWAY_API_KEY'), ('outlines', OutlinesProvider, None), ]