Commit e119392

Add Nebius AI Studio provider support (#3124)

antoncp (Anton Chaplygin) authored
Co-authored-by: Anton Chaplygin <[email protected]>

1 parent dab177a · commit e119392

File tree

12 files changed: +282 −3 lines changed


README.md

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI

    [Pydantic Validation](https://docs.pydantic.dev/latest/) is the validation layer of the OpenAI SDK, the Google ADK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more. _Why use the derivative when you can go straight to the source?_ :smiley:

 2. **Model-agnostic**:

-   Supports virtually every [model](https://ai.pydantic.dev/models/overview) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel. If your favorite model or provider is not listed, you can easily implement a [custom model](https://ai.pydantic.dev/models/overview#custom-models).
+   Supports virtually every [model](https://ai.pydantic.dev/models/overview) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius. If your favorite model or provider is not listed, you can easily implement a [custom model](https://ai.pydantic.dev/models/overview#custom-models).

 3. **Seamless Observability**:

    Tightly [integrates](https://ai.pydantic.dev/logfire) with [Pydantic Logfire](https://pydantic.dev/logfire), our general-purpose OpenTelemetry observability platform, for real-time debugging, evals-based performance monitoring, and behavior, tracing, and cost tracking. If you already have an observability platform that supports OTel, you can [use that too](https://ai.pydantic.dev/logfire#alternative-observability-backends).

docs/api/providers.md

Lines changed: 2 additions & 0 deletions

@@ -41,3 +41,5 @@
 ::: pydantic_ai.providers.ollama.OllamaProvider

 ::: pydantic_ai.providers.litellm.LiteLLMProvider
+
+::: pydantic_ai.providers.nebius.NebiusProvider

docs/index.md

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI

    [Pydantic Validation](https://docs.pydantic.dev/latest/) is the validation layer of the OpenAI SDK, the Google ADK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more. _Why use the derivative when you can go straight to the source?_ :smiley:

 2. **Model-agnostic**:

-   Supports virtually every [model](models/overview.md) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel. If your favorite model or provider is not listed, you can easily implement a [custom model](models/overview.md#custom-models).
+   Supports virtually every [model](models/overview.md) and provider: OpenAI, Anthropic, Gemini, DeepSeek, Grok, Cohere, Mistral, and Perplexity; Azure AI Foundry, Amazon Bedrock, Google Vertex AI, Ollama, LiteLLM, Groq, OpenRouter, Together AI, Fireworks AI, Cerebras, Hugging Face, GitHub, Heroku, Vercel, Nebius. If your favorite model or provider is not listed, you can easily implement a [custom model](models/overview.md#custom-models).

 3. **Seamless Observability**:

    Tightly [integrates](logfire.md) with [Pydantic Logfire](https://pydantic.dev/logfire), our general-purpose OpenTelemetry observability platform, for real-time debugging, evals-based performance monitoring, and behavior, tracing, and cost tracking. If you already have an observability platform that supports OTel, you can [use that too](logfire.md#alternative-observability-backends).

docs/models/openai.md

Lines changed: 32 additions & 0 deletions

@@ -608,3 +608,35 @@ (appended after the existing example ending `#> The capital of France is Paris.`)

### Nebius AI Studio

Go to [Nebius AI Studio](https://studio.nebius.com/) and create an API key.

Once you've set the `NEBIUS_API_KEY` environment variable, you can run the following:

```python
from pydantic_ai import Agent

agent = Agent('nebius:Qwen/Qwen3-32B-fast')
result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
```

If you need to configure the provider, you can use the [`NebiusProvider`][pydantic_ai.providers.nebius.NebiusProvider] class:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.nebius import NebiusProvider

model = OpenAIChatModel(
    'Qwen/Qwen3-32B-fast',
    provider=NebiusProvider(api_key='your-nebius-api-key'),
)
agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.output)
#> The capital of France is Paris.
```

docs/models/overview.md

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 - [GitHub Models](openai.md#github-models)
 - [Cerebras](openai.md#cerebras)
 - [LiteLLM](openai.md#litellm)
+- [Nebius AI Studio](openai.md#nebius-ai-studio)

 Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md)
 for testing and development.

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -691,6 +691,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'together',
         'vercel',
         'litellm',
+        'nebius',
     ):
         from .openai import OpenAIChatModel
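The `'nebius'` entry above joins the prefixes that `infer_model` routes to `OpenAIChatModel`. For readers unfamiliar with that dispatch, here is a minimal standalone sketch of the `'provider:model'` string handling (hypothetical helper names, not pydantic_ai's actual code):

```python
# Standalone sketch (hypothetical; not pydantic_ai's real implementation) of how
# a model identifier such as 'nebius:Qwen/Qwen3-32B-fast' is routed: the prefix
# before the first ':' selects the provider, the remainder is the provider-side
# model name, and OpenAI-compatible prefixes all map to one chat-model class.
OPENAI_COMPATIBLE = {
    'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks',
    'together', 'vercel', 'litellm', 'nebius',
}

def route_model(identifier: str) -> tuple[str, str, str]:
    """Split 'provider:model' and report which model class would handle it."""
    provider, sep, model_name = identifier.partition(':')
    if not sep:
        raise ValueError(f'expected provider:model, got {identifier!r}')
    model_class = 'OpenAIChatModel' if provider in OPENAI_COMPATIBLE else 'other'
    return provider, model_name, model_class

print(route_model('nebius:Qwen/Qwen3-32B-fast'))
#> ('nebius', 'Qwen/Qwen3-32B-fast', 'OpenAIChatModel')
```

Because the Nebius API is OpenAI-compatible, no new model class is needed; only the provider list grows.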

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 4 additions & 1 deletion
@@ -284,6 +284,7 @@ def __init__(
         'together',
         'vercel',
         'litellm',
+        'nebius',
     ]
     | Provider[AsyncOpenAI] = 'openai',
     profile: ModelProfileSpec | None = None,
@@ -312,6 +313,7 @@ def __init__(
         'together',
         'vercel',
         'litellm',
+        'nebius',
     ]
     | Provider[AsyncOpenAI] = 'openai',
     profile: ModelProfileSpec | None = None,
@@ -339,6 +341,7 @@ def __init__(
         'together',
         'vercel',
         'litellm',
+        'nebius',
     ]
     | Provider[AsyncOpenAI] = 'openai',
     profile: ModelProfileSpec | None = None,
@@ -899,7 +902,7 @@ def __init__(
     self,
     model_name: OpenAIModelName,
     *,
-    provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+    provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius']
     | Provider[AsyncOpenAI] = 'openai',
     profile: ModelProfileSpec | None = None,
     settings: ModelSettings | None = None,

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
142142
from .litellm import LiteLLMProvider
143143

144144
return LiteLLMProvider
145+
elif provider == 'nebius':
146+
from .nebius import NebiusProvider
147+
148+
return NebiusProvider
145149
else: # pragma: no cover
146150
raise ValueError(f'Unknown provider: {provider}')
147151
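`infer_provider_class` is an if/elif ladder with imports inside each branch so that a provider's optional dependencies load only when that provider is requested. A minimal standalone sketch of the same lazy-dispatch pattern (hypothetical names, with stdlib modules standing in for provider packages):

```python
# Sketch (hypothetical; not pydantic_ai's code) of lazy, name-based dispatch:
# the module behind a key is imported only on first request, so optional
# dependencies are not pulled in at package import time.
import importlib

_PROVIDER_MODULES = {
    # key -> (module path, attribute); stdlib stand-ins for illustration
    'json': ('json', 'JSONDecoder'),
    'csv': ('csv', 'DictReader'),
}

def infer_class(key: str):
    try:
        module_path, attr = _PROVIDER_MODULES[key]
    except KeyError:
        raise ValueError(f'Unknown provider: {key}') from None
    module = importlib.import_module(module_path)  # imported lazily, on demand
    return getattr(module, attr)

print(infer_class('json').__name__)
#> JSONDecoder
```

The real code uses explicit `elif` branches rather than a table, which keeps each import statement statically visible to type checkers and packaging tools.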

pydantic_ai_slim/pydantic_ai/providers/nebius.py (new file)

Lines changed: 102 additions & 0 deletions

```python
from __future__ import annotations as _annotations

import os
from typing import overload

import httpx

from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import google_model_profile
from pydantic_ai.profiles.harmony import harmony_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Nebius provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class NebiusProvider(Provider[AsyncOpenAI]):
    """Provider for Nebius AI Studio API."""

    @property
    def name(self) -> str:
        return 'nebius'

    @property
    def base_url(self) -> str:
        return 'https://api.studio.nebius.com/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        provider_to_profile = {
            'meta-llama': meta_model_profile,
            'deepseek-ai': deepseek_model_profile,
            'qwen': qwen_model_profile,
            'google': google_model_profile,
            'openai': harmony_model_profile,  # used for gpt-oss models on Nebius
            'mistralai': mistral_model_profile,
            'moonshotai': moonshotai_model_profile,
        }

        profile = None

        try:
            model_name = model_name.lower()
            provider, model_name = model_name.split('/', 1)
        except ValueError:
            raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}")
        if provider in provider_to_profile:
            profile = provider_to_profile[provider](model_name)

        # As NebiusProvider is always used with OpenAIChatModel, which used to unconditionally
        # use OpenAIJsonSchemaTransformer, we need to maintain that behavior
        # unless json_schema_transformer is set explicitly
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        api_key = api_key or os.getenv('NEBIUS_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `NEBIUS_API_KEY` environment variable or pass it via '
                '`NebiusProvider(api_key=...)` to use the Nebius AI Studio provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='nebius')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
```
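`model_profile` keys off the organization prefix in Nebius's `'provider/model'` names. The name handling can be sketched standalone (hypothetical return values; the real method returns `ModelProfile` objects and raises `UserError` rather than `ValueError`):

```python
# Standalone sketch (hypothetical) of NebiusProvider.model_profile's name
# handling: lowercase the name, split on the first '/', look the org prefix up
# in a table, and fall back to an OpenAI-style default for unknown prefixes.
PREFIX_TO_PROFILE = {
    'meta-llama': 'meta', 'deepseek-ai': 'deepseek', 'qwen': 'qwen',
    'google': 'google', 'openai': 'harmony', 'mistralai': 'mistral',
    'moonshotai': 'moonshotai',
}

def profile_for(model_name: str) -> str:
    lowered = model_name.lower()
    prefix, sep, _rest = lowered.partition('/')
    if not sep:
        raise ValueError(f"Model name must be in 'provider/model' format, got: {model_name!r}")
    return PREFIX_TO_PROFILE.get(prefix, 'openai-default')

print(profile_for('Qwen/Qwen3-32B-fast'))
#> qwen
print(profile_for('unknown-org/some-model'))
#> openai-default
```

Note the fallback: even with no prefix match, the real method still returns an `OpenAIModelProfile` with `OpenAIJsonSchemaTransformer`, preserving the behavior `OpenAIChatModel` previously applied unconditionally.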

tests/providers/test_nebius.py

Lines changed: 131 additions & 0 deletions (new file)

```python
import re

import httpx
import pytest
from pytest_mock import MockerFixture

from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile
from pydantic_ai.profiles.harmony import harmony_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
from pydantic_ai.profiles.qwen import qwen_model_profile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
    import openai

    from pydantic_ai.providers.nebius import NebiusProvider


pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.vcr,
    pytest.mark.anyio,
]


def test_nebius_provider():
    provider = NebiusProvider(api_key='api-key')
    assert provider.name == 'nebius'
    assert provider.base_url == 'https://api.studio.nebius.com/v1'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_nebius_provider_need_api_key(env: TestEnv) -> None:
    env.remove('NEBIUS_API_KEY')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `NEBIUS_API_KEY` environment variable or pass it via '
            '`NebiusProvider(api_key=...)` to use the Nebius AI Studio provider.'
        ),
    ):
        NebiusProvider()


def test_nebius_pass_openai_client() -> None:
    openai_client = openai.AsyncOpenAI(api_key='api-key')
    provider = NebiusProvider(openai_client=openai_client)
    assert provider.client == openai_client


def test_nebius_provider_pass_http_client() -> None:
    http_client = httpx.AsyncClient()
    provider = NebiusProvider(http_client=http_client, api_key='api-key')
    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]


def test_nebius_provider_model_profile(mocker: MockerFixture):
    provider = NebiusProvider(api_key='api-key')

    ns = 'pydantic_ai.providers.nebius'

    # Mock all profile functions
    meta_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile)
    deepseek_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    qwen_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile)
    google_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile)
    harmony_mock = mocker.patch(f'{ns}.harmony_model_profile', wraps=harmony_model_profile)
    mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    moonshotai_mock = mocker.patch(f'{ns}.moonshotai_model_profile', wraps=moonshotai_model_profile)

    # Test meta provider
    meta_profile = provider.model_profile('meta-llama/Llama-3.3-70B-Instruct')
    meta_mock.assert_called_with('llama-3.3-70b-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test deepseek provider
    profile = provider.model_profile('deepseek-ai/DeepSeek-R1-0528')
    deepseek_mock.assert_called_with('deepseek-r1-0528')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test qwen provider
    qwen_profile = provider.model_profile('Qwen/Qwen3-30B-A3B')
    qwen_mock.assert_called_with('qwen3-30b-a3b')
    assert qwen_profile is not None
    assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test google provider
    google_profile = provider.model_profile('google/gemma-2-2b-it')
    google_mock.assert_called_with('gemma-2-2b-it')
    assert google_profile is not None
    assert google_profile.json_schema_transformer == GoogleJsonSchemaTransformer

    # Test harmony (for openai gpt-oss) provider
    profile = provider.model_profile('openai/gpt-oss-120b')
    harmony_mock.assert_called_with('gpt-oss-120b')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test mistral provider
    profile = provider.model_profile('mistralai/Devstral-Small-2505')
    mistral_mock.assert_called_with('devstral-small-2505')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test moonshotai provider
    moonshotai_profile = provider.model_profile('moonshotai/Kimi-K2-Instruct')
    moonshotai_mock.assert_called_with('kimi-k2-instruct')
    assert moonshotai_profile is not None
    assert moonshotai_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test unknown provider
    unknown_profile = provider.model_profile('unknown-provider/unknown-model')
    assert unknown_profile is not None
    assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer


def test_nebius_provider_invalid_model_name():
    provider = NebiusProvider(api_key='api-key')

    with pytest.raises(UserError, match="Model name must be in 'provider/model' format"):
        provider.model_profile('invalid-model-name')
```

0 commit comments