Add Nebius AI Studio provider support #3124
New file: provider implementation (@@ -0,0 +1,102 @@)

```python
from __future__ import annotations as _annotations

import os
from typing import overload

import httpx

from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import google_model_profile
from pydantic_ai.profiles.harmony import harmony_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Nebius provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class NebiusProvider(Provider[AsyncOpenAI]):
    """Provider for Nebius AI Studio API."""

    @property
    def name(self) -> str:
        return 'nebius'

    @property
    def base_url(self) -> str:
        return 'https://api.studio.nebius.com/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        provider_to_profile = {
            'meta-llama': meta_model_profile,
            'deepseek-ai': deepseek_model_profile,
            'qwen': qwen_model_profile,
            'google': google_model_profile,
            'openai': harmony_model_profile,  # used for gpt-oss models on Nebius
            'mistralai': mistral_model_profile,
            'moonshotai': moonshotai_model_profile,
        }
```
Collaborator: Are these all of the providers Nebius supports?

Contributor (author): Yes, all currently supported Nebius providers are included here. They are listed at https://studio.nebius.com/
The diff continues:

```python
        profile = None

        try:
            model_name = model_name.lower()
            provider, model_name = model_name.split('/', 1)
        except ValueError:
            raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}")
        if provider in provider_to_profile:
            profile = provider_to_profile[provider](model_name)

        # As NebiusProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
        # we need to maintain that behavior unless json_schema_transformer is set explicitly
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        api_key = api_key or os.getenv('NEBIUS_API_KEY')
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `NEBIUS_API_KEY` environment variable or pass it via '
                '`NebiusProvider(api_key=...)` to use the Nebius AI Studio provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='nebius')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
```
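For context, a minimal usage sketch (not part of this diff). It assumes `OpenAIChatModel` accepts a `provider=` argument and that the run result exposes `.output`, as in recent pydantic-ai releases; the model name is only an example:

```python
# Usage sketch (not part of this PR's diff). Assumes NEBIUS_API_KEY is set,
# or pass api_key=... to NebiusProvider explicitly.
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.nebius import NebiusProvider

model = OpenAIChatModel(
    'meta-llama/Llama-3.3-70B-Instruct',  # Nebius model names use 'provider/model' format
    provider=NebiusProvider(),            # falls back to the NEBIUS_API_KEY env var
)
agent = Agent(model)
result = agent.run_sync('Say hello from Nebius AI Studio.')
print(result.output)
```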
New file: tests (@@ -0,0 +1,131 @@)

```python
import re

import httpx
import pytest
from pytest_mock import MockerFixture

from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile
from pydantic_ai.profiles.harmony import harmony_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.moonshotai import moonshotai_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
from pydantic_ai.profiles.qwen import qwen_model_profile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
    import openai

    from pydantic_ai.providers.nebius import NebiusProvider


pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.vcr,
    pytest.mark.anyio,
]


def test_nebius_provider():
    provider = NebiusProvider(api_key='api-key')
    assert provider.name == 'nebius'
    assert provider.base_url == 'https://api.studio.nebius.com/v1'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_nebius_provider_need_api_key(env: TestEnv) -> None:
    env.remove('NEBIUS_API_KEY')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `NEBIUS_API_KEY` environment variable or pass it via '
            '`NebiusProvider(api_key=...)` to use the Nebius AI Studio provider.'
        ),
    ):
        NebiusProvider()


def test_nebius_pass_openai_client() -> None:
    openai_client = openai.AsyncOpenAI(api_key='api-key')
    provider = NebiusProvider(openai_client=openai_client)
    assert provider.client == openai_client


def test_nebius_provider_pass_http_client() -> None:
    http_client = httpx.AsyncClient()
    provider = NebiusProvider(http_client=http_client, api_key='api-key')
    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]


def test_nebius_provider_model_profile(mocker: MockerFixture):
    provider = NebiusProvider(api_key='api-key')

    ns = 'pydantic_ai.providers.nebius'

    # Mock all profile functions
    meta_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile)
    deepseek_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    qwen_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile)
    google_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile)
    harmony_mock = mocker.patch(f'{ns}.harmony_model_profile', wraps=harmony_model_profile)
    mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    moonshotai_mock = mocker.patch(f'{ns}.moonshotai_model_profile', wraps=moonshotai_model_profile)

    # Test meta provider
    meta_profile = provider.model_profile('meta-llama/Llama-3.3-70B-Instruct')
    meta_mock.assert_called_with('llama-3.3-70b-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test deepseek provider
    profile = provider.model_profile('deepseek-ai/DeepSeek-R1-0528')
    deepseek_mock.assert_called_with('deepseek-r1-0528')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test qwen provider
    qwen_profile = provider.model_profile('Qwen/Qwen3-30B-A3B')
    qwen_mock.assert_called_with('qwen3-30b-a3b')
    assert qwen_profile is not None
    assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test google provider
    google_profile = provider.model_profile('google/gemma-2-2b-it')
    google_mock.assert_called_with('gemma-2-2b-it')
    assert google_profile is not None
    assert google_profile.json_schema_transformer == GoogleJsonSchemaTransformer

    # Test harmony (for openai gpt-oss) provider
    profile = provider.model_profile('openai/gpt-oss-120b')
    harmony_mock.assert_called_with('gpt-oss-120b')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test mistral provider
    profile = provider.model_profile('mistralai/Devstral-Small-2505')
    mistral_mock.assert_called_with('devstral-small-2505')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test moonshotai provider
    moonshotai_profile = provider.model_profile('moonshotai/Kimi-K2-Instruct')
    moonshotai_mock.assert_called_with('kimi-k2-instruct')
    assert moonshotai_profile is not None
    assert moonshotai_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test unknown provider
    unknown_profile = provider.model_profile('unknown-provider/unknown-model')
    assert unknown_profile is not None
    assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer


def test_nebius_provider_invalid_model_name():
    provider = NebiusProvider(api_key='api-key')

    with pytest.raises(UserError, match="Model name must be in 'provider/model' format"):
        provider.model_profile('invalid-model-name')
```
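To summarize the behavior these tests pin down, a small sketch using only names from this diff (no network calls are made):

```python
# Sketch of the profile-selection behavior exercised above (not part of the diff).
from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer
from pydantic_ai.providers.nebius import NebiusProvider

provider = NebiusProvider(api_key='api-key')

# Unknown vendor prefixes still fall back to the OpenAI-compatible defaults.
profile = provider.model_profile('unknown-provider/unknown-model')
assert profile is not None
assert profile.json_schema_transformer is OpenAIJsonSchemaTransformer

# Model names without a '/' separator are rejected with a UserError.
try:
    provider.model_profile('invalid-model-name')
except UserError as exc:
    print(exc)  # Model name must be in 'provider/model' format, got: 'invalid-model-name'
```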