From 3261d21818a63e79c9202a43ae202986fca428a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patryk=20Ga=C5=82a?= Date: Thu, 12 Feb 2026 22:39:44 +0100 Subject: [PATCH] Optimize cold start of neptune-query --- .../docker/rofiles/neptune_api/client.py | 14 +--- .../docker/rofiles/templates/types.py.jinja | 12 +++- .../generated/neptune_api/client.py | 14 +--- .../generated/neptune_api/types.py | 12 +++- src/neptune_query/internal/api_utils.py | 71 +------------------ src/neptune_query/internal/client.py | 13 +--- tests/e2e/conftest.py | 10 +-- tests/performance_e2e/conftest.py | 11 ++- tests/unit/internal/test_api_client_cache.py | 6 +- tests/unit/neptune_api/conftest.py | 26 +++++-- tests/unit/neptune_api/test_authenticator.py | 5 +- tests/unit/neptune_api/test_oauth_token.py | 1 + 12 files changed, 64 insertions(+), 131 deletions(-) diff --git a/src/neptune_api_codegen/docker/rofiles/neptune_api/client.py b/src/neptune_api_codegen/docker/rofiles/neptune_api/client.py index 83f9633c..38d3f4da 100644 --- a/src/neptune_api_codegen/docker/rofiles/neptune_api/client.py +++ b/src/neptune_api_codegen/docker/rofiles/neptune_api/client.py @@ -198,8 +198,6 @@ class AuthenticatedClient: status code that was not documented in the source OpenAPI document. Can also be provided as a keyword argument to the constructor. credentials: User credentials for authentication. - token_refreshing_endpoint: Token refreshing endpoint url - client_id: Client identifier for the OAuth application. api_key_exchange_callback: The Neptune API Token exchange function prefix: The prefix to use for the Authorization header auth_header_name: The name of the Authorization header @@ -218,8 +216,6 @@ class AuthenticatedClient: _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) credentials: Credentials - token_refreshing_endpoint: str - client_id: str api_key_exchange_callback: Callable[[Client, Credentials], OAuthToken] prefix: str = "Bearer" auth_header_name: str = "Authorization" @@ -299,8 +295,6 @@ def get_httpx_client(self) -> httpx.Client: self._client = httpx.Client( auth=NeptuneAuthenticator( credentials=self.credentials, - client_id=self.client_id, - token_refreshing_endpoint=self.token_refreshing_endpoint, api_key_exchange_factory=self.api_key_exchange_callback, client=self.get_token_refreshing_client(), ), @@ -355,14 +349,10 @@ class NeptuneAuthenticator(httpx.Auth): def __init__( self, credentials: Credentials, - client_id: str, - token_refreshing_endpoint: str, api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken], client: Client, ): self._credentials: Credentials = credentials - self._client_id: str = client_id - self._token_refreshing_endpoint: str = token_refreshing_endpoint self._api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken] = api_key_exchange_factory self._client = client @@ -373,11 +363,11 @@ def _refresh_existing_token(self) -> OAuthToken: raise ValueError("Cannot refresh an empty token") try: response = self._client.get_httpx_client().post( - url=self._token_refreshing_endpoint, + url=self._token.token_endpoint, data={ "grant_type": "refresh_token", "refresh_token": self._token.refresh_token, - "client_id": self._client_id, + "client_id": self._token.client_id, "expires_in": self._token.seconds_left, }, ) diff --git a/src/neptune_api_codegen/docker/rofiles/templates/types.py.jinja b/src/neptune_api_codegen/docker/rofiles/templates/types.py.jinja index f9a24aad..6e88432f 100644 --- a/src/neptune_api_codegen/docker/rofiles/templates/types.py.jinja +++ b/src/neptune_api_codegen/docker/rofiles/templates/types.py.jinja @@ -91,6 +91,8 @@ class OAuthToken: _expiration_time: float = field(default=0.0, alias="expiration_time", kw_only=True) access_token: str refresh_token: str + client_id: str + token_endpoint: str @classmethod def from_tokens(cls, access: str, refresh: str) -> "OAuthToken": @@ -98,10 +100,18 @@ class OAuthToken: try: decoded_token = jwt.decode(access, options=DECODING_OPTIONS) expiration_time = float(decoded_token["exp"]) + client_id = str(decoded_token["azp"]) + issuer = str(decoded_token["iss"]).rstrip("/") except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError) as e: raise InvalidApiTokenException("Cannot decode the access token") from e - return OAuthToken(access_token=access, refresh_token=refresh, expiration_time=expiration_time) + return OAuthToken( + access_token=access, + refresh_token=refresh, + expiration_time=expiration_time, + client_id=client_id, + token_endpoint=f"{issuer}/protocol/openid-connect/token", + ) @property def seconds_left(self) -> float: diff --git a/src/neptune_query/generated/neptune_api/client.py b/src/neptune_query/generated/neptune_api/client.py index 83f9633c..38d3f4da 100644 --- a/src/neptune_query/generated/neptune_api/client.py +++ b/src/neptune_query/generated/neptune_api/client.py @@ -198,8 +198,6 @@ class AuthenticatedClient: status code that was not documented in the source OpenAPI document. Can also be provided as a keyword argument to the constructor. credentials: User credentials for authentication. - token_refreshing_endpoint: Token refreshing endpoint url - client_id: Client identifier for the OAuth application. api_key_exchange_callback: The Neptune API Token exchange function prefix: The prefix to use for the Authorization header auth_header_name: The name of the Authorization header @@ -218,8 +216,6 @@ class AuthenticatedClient: _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) credentials: Credentials - token_refreshing_endpoint: str - client_id: str api_key_exchange_callback: Callable[[Client, Credentials], OAuthToken] prefix: str = "Bearer" auth_header_name: str = "Authorization" @@ -299,8 +295,6 @@ def get_httpx_client(self) -> httpx.Client: self._client = httpx.Client( auth=NeptuneAuthenticator( credentials=self.credentials, - client_id=self.client_id, - token_refreshing_endpoint=self.token_refreshing_endpoint, api_key_exchange_factory=self.api_key_exchange_callback, client=self.get_token_refreshing_client(), ), @@ -355,14 +349,10 @@ class NeptuneAuthenticator(httpx.Auth): def __init__( self, credentials: Credentials, - client_id: str, - token_refreshing_endpoint: str, api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken], client: Client, ): self._credentials: Credentials = credentials - self._client_id: str = client_id - self._token_refreshing_endpoint: str = token_refreshing_endpoint self._api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken] = api_key_exchange_factory self._client = client @@ -373,11 +363,11 @@ def _refresh_existing_token(self) -> OAuthToken: raise ValueError("Cannot refresh an empty token") try: response = self._client.get_httpx_client().post( - url=self._token_refreshing_endpoint, + url=self._token.token_endpoint, data={ "grant_type": "refresh_token", "refresh_token": self._token.refresh_token, - "client_id": self._client_id, + "client_id": self._token.client_id, "expires_in": self._token.seconds_left, }, ) diff --git a/src/neptune_query/generated/neptune_api/types.py b/src/neptune_query/generated/neptune_api/types.py index 082c5fa7..a57fd510 100644 --- a/src/neptune_query/generated/neptune_api/types.py +++ b/src/neptune_query/generated/neptune_api/types.py @@ -90,6 +90,8 @@ class OAuthToken: _expiration_time: float = field(default=0.0, alias="expiration_time", kw_only=True) access_token: str refresh_token: str + client_id: str + token_endpoint: str @classmethod def from_tokens(cls, access: str, refresh: str) -> "OAuthToken": @@ -97,10 +99,18 @@ def from_tokens(cls, access: str, refresh: str) -> "OAuthToken": try: decoded_token = jwt.decode(access, options=DECODING_OPTIONS) expiration_time = float(decoded_token["exp"]) + client_id = str(decoded_token["azp"]) + issuer = str(decoded_token["iss"]).rstrip("/") except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError) as e: raise InvalidApiTokenException("Cannot decode the access token") from e - return OAuthToken(access_token=access, refresh_token=refresh, expiration_time=expiration_time) + return OAuthToken( + access_token=access, + refresh_token=refresh, + expiration_time=expiration_time, + client_id=client_id, + token_endpoint=f"{issuer}/protocol/openid-connect/token", + ) @property def seconds_left(self) -> float: diff --git a/src/neptune_query/internal/api_utils.py b/src/neptune_query/internal/api_utils.py index 4d10ed0c..4f08b2b6 100644 --- a/src/neptune_query/internal/api_utils.py +++ b/src/neptune_query/internal/api_utils.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass -from http import HTTPStatus from typing import ( Callable, Dict, @@ -23,91 +21,24 @@ import httpx -from neptune_query.generated.neptune_api import ( - AuthenticatedClient, - Client, -) -from neptune_query.generated.neptune_api.api.backend import get_client_config +from neptune_query.generated.neptune_api import AuthenticatedClient from neptune_query.generated.neptune_api.auth_helpers import exchange_api_key from neptune_query.generated.neptune_api.credentials import Credentials -from neptune_query.generated.neptune_api.models import ClientConfig -from neptune_query.generated.neptune_api.types import Response -from ..exceptions import NeptuneFailedToFetchClientConfig from .env import ( NEPTUNE_HTTP_REQUEST_TIMEOUT_SECONDS, NEPTUNE_VERIFY_SSL, ) -from .retrieval.retry import handle_errors_default - - -@dataclass -class TokenRefreshingURLs: - authorization_endpoint: str - token_endpoint: str - - @classmethod - def from_dict(cls, data: dict) -> "TokenRefreshingURLs": - return TokenRefreshingURLs( - authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"] - ) - - -def _wrap_httpx_json_response(httpx_response: httpx.Response) -> Response: - """Wrap a httpx.Response into an neptune-api Response object that is compatible - with backoff_retry(). Use .json() as parsed content in the result.""" - - return Response( - status_code=HTTPStatus(httpx_response.status_code), - content=httpx_response.content, - headers=httpx_response.headers, - parsed=httpx_response.json(), - url=str(httpx_response.url), - ) - - -def get_config_and_token_urls( - *, credentials: Credentials, proxies: Optional[Dict[str, str]] = None -) -> tuple[ClientConfig, TokenRefreshingURLs]: - timeout = httpx.Timeout(NEPTUNE_HTTP_REQUEST_TIMEOUT_SECONDS.get()) - with Client( - base_url=credentials.base_url, - httpx_args={"mounts": proxies}, - verify_ssl=NEPTUNE_VERIFY_SSL.get(), - timeout=timeout, - headers={"User-Agent": _generate_user_agent()}, - ) as client: - try: - config_response = handle_errors_default(get_client_config.sync_detailed)(client=client) - config = config_response.parsed - if not isinstance(config, ClientConfig): - raise RuntimeError(f"Expected ClientConfig but got {type(config).__name__}") - - urls_response = handle_errors_default( - lambda: _wrap_httpx_json_response(client.get_httpx_client().get(config.security.open_id_discovery)) - )() - token_urls_dict = urls_response.parsed - if not isinstance(token_urls_dict, dict): - raise RuntimeError(f"Expected dict but got {type(token_urls_dict).__name__}") - token_urls = TokenRefreshingURLs.from_dict(token_urls_dict) - - return config, token_urls - except Exception as e: - raise NeptuneFailedToFetchClientConfig(exception=e) from e def create_auth_api_client( *, credentials: Credentials, - config: ClientConfig, - token_refreshing_urls: TokenRefreshingURLs, proxies: Optional[Dict[str, str]] = None, ) -> AuthenticatedClient: return AuthenticatedClient( base_url=credentials.base_url, credentials=credentials, - client_id=config.security.client_id, - token_refreshing_endpoint=token_refreshing_urls.token_endpoint, api_key_exchange_callback=exchange_api_key, verify_ssl=NEPTUNE_VERIFY_SSL.get(), httpx_args={"mounts": proxies, "http2": False}, diff --git a/src/neptune_query/internal/client.py b/src/neptune_query/internal/client.py index a4cb742b..bc016560 100644 --- a/src/neptune_query/internal/client.py +++ b/src/neptune_query/internal/client.py @@ -28,10 +28,7 @@ from neptune_query.generated.neptune_api import AuthenticatedClient from neptune_query.generated.neptune_api.credentials import Credentials -from .api_utils import ( - create_auth_api_client, - get_config_and_token_urls, -) +from .api_utils import create_auth_api_client from .context import Context # Disable httpx logging, httpx logs requests at INFO level @@ -58,13 +55,7 @@ def get_client(context: Context, proxies: Optional[Dict[str, str]] = None) -> Au raise ValueError("API token is not set") credentials = Credentials.from_api_key(api_key=context.api_token) - config, token_urls = get_config_and_token_urls(credentials=credentials, proxies=proxies) - client = create_auth_api_client( - credentials=credentials, - config=config, - token_refreshing_urls=token_urls, - proxies=proxies, - ) + client = create_auth_api_client(credentials=credentials, proxies=proxies) _cache[hash_key] = client return client diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 26185c32..4596fbb1 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -16,10 +16,7 @@ from neptune_query.generated.neptune_api import AuthenticatedClient from neptune_query.generated.neptune_api.credentials import Credentials -from neptune_query.internal.api_utils import ( - create_auth_api_client, - get_config_and_token_urls, -) +from neptune_query.internal.api_utils import create_auth_api_client from neptune_query.internal.composition import concurrency from neptune_query.internal.context import set_api_token from tests.e2e.data_ingestion import ( @@ -93,10 +90,7 @@ def set_api_token_auto(api_token) -> None: @pytest.fixture(scope="session") def client(api_token) -> AuthenticatedClient: credentials = Credentials.from_api_key(api_key=api_token) - config, token_urls = get_config_and_token_urls(credentials=credentials, proxies=None) - client = create_auth_api_client( - credentials=credentials, config=config, token_refreshing_urls=token_urls, proxies=None - ) + client = create_auth_api_client(credentials=credentials, proxies=None) return client diff --git a/tests/performance_e2e/conftest.py b/tests/performance_e2e/conftest.py index 37b28d4d..87de2cdf 100644 --- a/tests/performance_e2e/conftest.py +++ b/tests/performance_e2e/conftest.py @@ -210,12 +210,17 @@ def http_client(monkeypatch, backend_base_url: str, api_token: str) -> ClientPro Returns: A wrapper for the Neptune API client with header management """ - never_expiring_token = OAuthToken(access_token="x", refresh_token="x", expiration_time=time.time() + 10_000_000) + realm_base_url = f"{backend_base_url.rstrip('/')}/auth/realms/neptune" + never_expiring_token = OAuthToken( + access_token="x", + refresh_token="x", + expiration_time=time.time() + 10_000_000, + client_id="perf-test-client-id", + token_endpoint=f"{realm_base_url}/protocol/openid-connect/token", + ) patched_client = AuthenticatedClient( base_url=backend_base_url, credentials=Credentials.from_api_key(api_token), - client_id="", - token_refreshing_endpoint="", api_key_exchange_callback=lambda _client, _credentials: never_expiring_token, verify_ssl=False, httpx_args={"http2": False}, diff --git a/tests/unit/internal/test_api_client_cache.py b/tests/unit/internal/test_api_client_cache.py index 95cb0bc9..d3ade017 100644 --- a/tests/unit/internal/test_api_client_cache.py +++ b/tests/unit/internal/test_api_client_cache.py @@ -69,11 +69,7 @@ def clear_cache_before_test(): @fixture(autouse=True) def mock_networking(): - with ( - patch("neptune_query.internal.client.get_config_and_token_urls") as get_config_and_token_urls, - patch("neptune_query.internal.client.create_auth_api_client") as create_auth_api_client, - ): - get_config_and_token_urls.return_value = (Mock(), Mock()) + with patch("neptune_query.internal.client.create_auth_api_client") as create_auth_api_client: # create_auth_api_client() needs to return a different "client" each time create_auth_api_client.side_effect = lambda *args, **kwargs: Mock() yield diff --git a/tests/unit/neptune_api/conftest.py b/tests/unit/neptune_api/conftest.py index 74e6c3e9..198b22cd 100644 --- a/tests/unit/neptune_api/conftest.py +++ b/tests/unit/neptune_api/conftest.py @@ -18,6 +18,8 @@ FIXED_TIME = datetime(2024, 1, 2, 3, 4, 5, tzinfo=timezone.utc) EXPIRATION_TIME = FIXED_TIME + timedelta(seconds=MINIMAL_EXPIRATION_SECONDS + 1) +ISSUER = "https://dev.neptune.internal.openai.org/auth/realms/neptune" +CLIENT_ID = "test-client-id" @pytest.fixture @@ -30,14 +32,30 @@ def credentials() -> Credentials: @pytest.fixture def oauth_token() -> OAuthToken: return OAuthToken.from_tokens( - access=jwt.encode({"exp": datetime.timestamp(FIXED_TIME)}, "secret", algorithm="HS256"), - refresh=jwt.encode({"exp": datetime.timestamp(FIXED_TIME)}, "secret", algorithm="HS256"), + access=jwt.encode( + {"exp": datetime.timestamp(FIXED_TIME), "azp": CLIENT_ID, "iss": ISSUER}, + "secret", + algorithm="HS256", + ), + refresh=jwt.encode( + {"exp": datetime.timestamp(FIXED_TIME), "azp": CLIENT_ID, "iss": ISSUER}, + "secret", + algorithm="HS256", + ), ) @pytest.fixture def expired_oauth_token() -> OAuthToken: return OAuthToken.from_tokens( - access=jwt.encode({"exp": datetime.timestamp(EXPIRATION_TIME)}, "secret", algorithm="HS256"), - refresh=jwt.encode({"exp": datetime.timestamp(EXPIRATION_TIME)}, "secret", algorithm="HS256"), + access=jwt.encode( + {"exp": datetime.timestamp(EXPIRATION_TIME), "azp": CLIENT_ID, "iss": ISSUER}, + "secret", + algorithm="HS256", + ), + refresh=jwt.encode( + {"exp": datetime.timestamp(EXPIRATION_TIME), "azp": CLIENT_ID, "iss": ISSUER}, + "secret", + algorithm="HS256", + ), ) diff --git a/tests/unit/neptune_api/test_authenticator.py b/tests/unit/neptune_api/test_authenticator.py index 30300234..3e2e0dcb 100644 --- a/tests/unit/neptune_api/test_authenticator.py +++ b/tests/unit/neptune_api/test_authenticator.py @@ -8,8 +8,6 @@ def test_use_token_factory(mocker, credentials, oauth_token): client = mocker.MagicMock() authenticator = NeptuneAuthenticator( credentials=credentials, - client_id="client_id", - token_refreshing_endpoint="https://api.neptune.ai/oauth/token", api_key_exchange_factory=(lambda _, __: oauth_token), client=client, ) @@ -29,8 +27,6 @@ def test_refresh(mocker, credentials, expired_oauth_token, oauth_token): client = mocker.MagicMock() authenticator = NeptuneAuthenticator( credentials=credentials, - client_id="client_id", - token_refreshing_endpoint="https://api.neptune.ai/oauth/token", client=client, api_key_exchange_factory=token_factory_stub, ) @@ -53,6 +49,7 @@ def test_refresh(mocker, credentials, expired_oauth_token, oauth_token): # then assert request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" + assert client.get_httpx_client().post.call_args.kwargs["url"] == expired_oauth_token.token_endpoint def token_factory_stub(client, credentials): diff --git a/tests/unit/neptune_api/test_oauth_token.py b/tests/unit/neptune_api/test_oauth_token.py index cee26765..27dec961 100644 --- a/tests/unit/neptune_api/test_oauth_token.py +++ b/tests/unit/neptune_api/test_oauth_token.py @@ -26,6 +26,7 @@ def test_almost_expired(oauth_token): # then assert token.seconds_left == 1 assert token.is_expired is False + assert token.token_endpoint.endswith("/protocol/openid-connect/token") @freeze_time("2024-01-02 03:03:35 UTC")