Skip to content

Commit 48ad9c9

Browse files
committed
Add gateway route support to provider inference
1 parent 3450322 commit 48ad9c9

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations as _annotations
77

8+
import re
89
from abc import ABC, abstractmethod
910
from typing import Any, Generic, TypeVar
1011

@@ -153,13 +154,24 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
153154
raise ValueError(f'Unknown provider: {provider}')
154155

155156

157+
_gateway_provider_regex = re.compile('gateway(?::([^/]*))?/(.*)')
158+
"""
159+
This regex matches strings of the forms:
160+
1. gateway/my-upstream-provider
161+
2. gateway:my-route/my-upstream-provider
162+
163+
The (optional) route is the first group, the upstream provider is the second.
164+
"""
165+
166+
156167
def infer_provider(provider: str) -> Provider[Any]:
157168
"""Infer the provider from the provider name."""
158-
if provider.startswith('gateway/'):
169+
if match := re.match(_gateway_provider_regex, provider):
159170
from .gateway import gateway_provider
160171

161-
upstream_provider = provider.removeprefix('gateway/')
162-
return gateway_provider(upstream_provider)
172+
route = match.group(1)
173+
upstream_provider = match.group(2)
174+
return gateway_provider(upstream_provider, route=route)
163175
elif provider in ('google-vertex', 'google-gla'):
164176
from .google import GoogleProvider
165177

0 commit comments

Comments
 (0)