
Commit 739395d

Author: Anton Chaplygin

Add Nebius AI Studio provider support

Parent: 6cf43ea

10 files changed: +278 additions, -1 deletion

docs/api/providers.md
Lines changed: 2 additions & 0 deletions

@@ -41,3 +41,5 @@
 ::: pydantic_ai.providers.ollama.OllamaProvider
 
 ::: pydantic_ai.providers.litellm.LiteLLMProvider
+
+::: pydantic_ai.providers.nebius.NebiusProvider

docs/models/openai.md
Lines changed: 32 additions & 0 deletions

@@ -608,3 +608,35 @@ print(result.output)
 #> The capital of France is Paris.
 ...
 ```
+
+### Nebius AI Studio
+
+Go to [Nebius AI Studio](https://studio.nebius.com/) and create an API key.
+
+Once you've set the `NEBIUS_API_KEY` environment variable, you can run the following:
+
+```python
+from pydantic_ai import Agent
+
+agent = Agent('nebius:Qwen/Qwen3-32B-fast')
+result = agent.run_sync('What is the capital of France?')
+print(result.output)
+#> The capital of France is Paris.
+```
+
+If you need to configure the provider, you can use the [`NebiusProvider`][pydantic_ai.providers.nebius.NebiusProvider] class:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.nebius import NebiusProvider
+
+model = OpenAIChatModel(
+    'Qwen/Qwen3-32B-fast',
+    provider=NebiusProvider(api_key='your-nebius-api-key'),
+)
+agent = Agent(model)
+result = agent.run_sync('What is the capital of France?')
+print(result.output)
+#> The capital of France is Paris.
+```

docs/models/overview.md
Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 - [GitHub Models](openai.md#github-models)
 - [Cerebras](openai.md#cerebras)
 - [LiteLLM](openai.md#litellm)
+- [Nebius AI Studio](openai.md#nebius-ai-studio)
 
 Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md)
 for testing and development.

pydantic_ai_slim/pydantic_ai/models/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -691,6 +691,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
         'together',
         'vercel',
         'litellm',
+        'nebius',
     ):
         from .openai import OpenAIChatModel

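With `'nebius'` added to the list of OpenAI-compatible prefixes, `infer_model` can resolve a `nebius:`-prefixed model string to an `OpenAIChatModel`. A minimal sketch of what this change enables (assuming `NEBIUS_API_KEY` is set; model name taken from the docs diff above):

```python
from pydantic_ai.models import infer_model

# 'nebius:<model>' now falls into the OpenAI-compatible branch above,
# so infer_model should return an OpenAIChatModel wired to the Nebius provider.
model = infer_model('nebius:Qwen/Qwen3-32B-fast')
print(type(model).__name__)  # expected: OpenAIChatModel
```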
pydantic_ai_slim/pydantic_ai/models/openai.py
Lines changed: 4 additions & 1 deletion

@@ -284,6 +284,7 @@ def __init__(
             'together',
             'vercel',
             'litellm',
+            'nebius',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -312,6 +313,7 @@ def __init__(
             'together',
             'vercel',
             'litellm',
+            'nebius',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -339,6 +341,7 @@ def __init__(
             'together',
             'vercel',
             'litellm',
+            'nebius',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -899,7 +902,7 @@ def __init__(
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius']
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,

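Since `'nebius'` is now part of the provider `Literal`, `OpenAIChatModel` also accepts the plain string form. A short sketch (the API key is assumed to come from `NEBIUS_API_KEY`):

```python
from pydantic_ai.models.openai import OpenAIChatModel

# The string provider is resolved internally to NebiusProvider,
# which reads NEBIUS_API_KEY from the environment.
model = OpenAIChatModel('Qwen/Qwen3-32B-fast', provider='nebius')
```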
pydantic_ai_slim/pydantic_ai/providers/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -142,6 +142,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
         from .litellm import LiteLLMProvider
 
         return LiteLLMProvider
+    elif provider == 'nebius':
+        from .nebius import NebiusProvider
+
+        return NebiusProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')

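The new branch means the provider registry can resolve the string `'nebius'` by name. A quick sketch of the lookup:

```python
from pydantic_ai.providers import infer_provider_class

# The elif branch added above returns the NebiusProvider class for 'nebius'.
provider_cls = infer_provider_class('nebius')
print(provider_cls.__name__)  # expected: NebiusProvider
```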
pydantic_ai_slim/pydantic_ai/providers/nebius.py (new file)
Lines changed: 101 additions & 0 deletions

from __future__ import annotations as _annotations

import os
from typing import overload

import httpx

from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import google_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Nebius provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class NebiusProvider(Provider[AsyncOpenAI]):
    """Provider for Nebius AI Studio API."""

    @property
    def name(self) -> str:
        return 'nebius'

    @property
    def base_url(self) -> str:
        return 'https://api.studio.nebius.com/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        provider_to_profile = {
            'meta-llama': meta_model_profile,
            'deepseek-ai': deepseek_model_profile,
            'qwen': qwen_model_profile,
            'google': google_model_profile,
            'openai': openai_model_profile,
            'mistralai': mistral_model_profile,
            'moonshotai': moonshotai_model_profile,
        }

        profile = None

        try:
            model_name = model_name.lower()
            provider, model_name = model_name.split('/', 1)
        except ValueError:
            raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}")
        if provider in provider_to_profile:
            profile = provider_to_profile[provider](model_name)

        # As NebiusProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
        # we need to maintain that behavior unless json_schema_transformer is set explicitly
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        api_key = api_key or os.getenv('NEBIUS_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `NEBIUS_API_KEY` environment variable or pass it via '
                '`NebiusProvider(api_key=...)` to use the Nebius AI Studio provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='nebius')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)

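The constructor overloads above also accept a custom `httpx.AsyncClient` alongside the API key, for example to control timeouts. A usage sketch (key and timeout values are placeholders; `NEBIUS_API_KEY` from the environment works too):

```python
import httpx

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.nebius import NebiusProvider

# Placeholder values for illustration only.
custom_http_client = httpx.AsyncClient(timeout=30)
model = OpenAIChatModel(
    'Qwen/Qwen3-32B-fast',
    provider=NebiusProvider(api_key='your-nebius-api-key', http_client=custom_http_client),
)
agent = Agent(model)
```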
tests/providers/test_nebius.py (new file)
Lines changed: 130 additions & 0 deletions

import re

import httpx
import pytest
from pytest_mock import MockerFixture

from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile
from pydantic_ai.profiles.qwen import qwen_model_profile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
    import openai

    from pydantic_ai.providers.nebius import NebiusProvider


pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.vcr,
    pytest.mark.anyio,
]


def test_nebius_provider():
    provider = NebiusProvider(api_key='api-key')
    assert provider.name == 'nebius'
    assert provider.base_url == 'https://api.studio.nebius.com/v1'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_nebius_provider_need_api_key(env: TestEnv) -> None:
    env.remove('NEBIUS_API_KEY')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `NEBIUS_API_KEY` environment variable or pass it via '
            '`NebiusProvider(api_key=...)` to use the Nebius AI Studio provider.'
        ),
    ):
        NebiusProvider()


def test_nebius_pass_openai_client() -> None:
    openai_client = openai.AsyncOpenAI(api_key='api-key')
    provider = NebiusProvider(openai_client=openai_client)
    assert provider.client == openai_client


def test_nebius_provider_pass_http_client() -> None:
    http_client = httpx.AsyncClient()
    provider = NebiusProvider(http_client=http_client, api_key='api-key')
    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]


def test_nebius_provider_model_profile(mocker: MockerFixture):
    provider = NebiusProvider(api_key='api-key')

    ns = 'pydantic_ai.providers.nebius'

    # Mock all profile functions
    meta_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile)
    deepseek_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    qwen_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile)
    google_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile)
    openai_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile)
    mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    moonshotai_mock = mocker.patch(f'{ns}.moonshotai_model_profile', wraps=moonshotai_model_profile)

    # Test meta provider
    meta_profile = provider.model_profile('meta-llama/Llama-3.3-70B-Instruct')
    meta_mock.assert_called_with('llama-3.3-70b-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test deepseek provider
    profile = provider.model_profile('deepseek-ai/DeepSeek-R1-0528')
    deepseek_mock.assert_called_with('deepseek-r1-0528')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test qwen provider
    qwen_profile = provider.model_profile('Qwen/Qwen3-30B-A3B')
    qwen_mock.assert_called_with('qwen3-30b-a3b')
    assert qwen_profile is not None
    assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test google provider
    google_profile = provider.model_profile('google/gemma-2-2b-it')
    google_mock.assert_called_with('gemma-2-2b-it')
    assert google_profile is not None
    assert google_profile.json_schema_transformer == GoogleJsonSchemaTransformer

    # Test openai provider
    profile = provider.model_profile('openai/gpt-oss-120b')
    openai_mock.assert_called_with('gpt-oss-120b')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test mistral provider
    profile = provider.model_profile('mistralai/Devstral-Small-2505')
    mistral_mock.assert_called_with('devstral-small-2505')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test moonshotai provider
    moonshotai_profile = provider.model_profile('moonshotai/Kimi-K2-Instruct')
    moonshotai_mock.assert_called_with('kimi-k2-instruct')
    assert moonshotai_profile is not None
    assert moonshotai_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test unknown provider
    unknown_profile = provider.model_profile('unknown-provider/unknown-model')
    assert unknown_profile is not None
    assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer


def test_nebius_provider_invalid_model_name():
    provider = NebiusProvider(api_key='api-key')

    with pytest.raises(UserError, match="Model name must be in 'provider/model' format"):
        provider.model_profile('invalid-model-name')

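As these tests exercise, `model_profile` keys off the org prefix before the slash, falls back to the OpenAI defaults for unknown orgs, and raises `UserError` for names without a slash. A small recap sketch of that behavior:

```python
from pydantic_ai.providers.nebius import NebiusProvider

provider = NebiusProvider(api_key='api-key')

# Known org prefix: delegates to meta_model_profile, then applies the OpenAI transformer default.
profile = provider.model_profile('meta-llama/Llama-3.3-70B-Instruct')

# Unknown org prefix: still returns an OpenAIModelProfile with OpenAIJsonSchemaTransformer.
fallback = provider.model_profile('unknown-provider/unknown-model')
```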
tests/providers/test_provider_names.py
Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
 from pydantic_ai.providers.litellm import LiteLLMProvider
 from pydantic_ai.providers.mistral import MistralProvider
 from pydantic_ai.providers.moonshotai import MoonshotAIProvider
+from pydantic_ai.providers.nebius import NebiusProvider
 from pydantic_ai.providers.ollama import OllamaProvider
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.providers.openrouter import OpenRouterProvider
@@ -54,6 +55,7 @@
     ('github', GitHubProvider, 'GITHUB_API_KEY'),
     ('ollama', OllamaProvider, 'OLLAMA_BASE_URL'),
     ('litellm', LiteLLMProvider, None),
+    ('nebius', NebiusProvider, 'NEBIUS_API_KEY'),
 ]
 
 if not imports_successful():

tests/test_examples.py
Lines changed: 1 addition & 0 deletions

@@ -170,6 +170,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
     env.set('AWS_DEFAULT_REGION', 'us-east-1')
     env.set('VERCEL_AI_GATEWAY_API_KEY', 'testing')
     env.set('CEREBRAS_API_KEY', 'testing')
+    env.set('NEBIUS_API_KEY', 'testing')
 
     prefix_settings = example.prefix_settings()
     opt_test = prefix_settings.get('test', '')
