Skip to content

Commit 3261d21

Browse files
Optimize cold start of neptune-query
1 parent 5701269 commit 3261d21

File tree

12 files changed

+64
-131
lines changed

12 files changed

+64
-131
lines changed

src/neptune_api_codegen/docker/rofiles/neptune_api/client.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,6 @@ class AuthenticatedClient:
198198
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
199199
argument to the constructor.
200200
credentials: User credentials for authentication.
201-
token_refreshing_endpoint: Token refreshing endpoint url
202-
client_id: Client identifier for the OAuth application.
203201
api_key_exchange_callback: The Neptune API Token exchange function
204202
prefix: The prefix to use for the Authorization header
205203
auth_header_name: The name of the Authorization header
@@ -218,8 +216,6 @@ class AuthenticatedClient:
218216
_async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
219217

220218
credentials: Credentials
221-
token_refreshing_endpoint: str
222-
client_id: str
223219
api_key_exchange_callback: Callable[[Client, Credentials], OAuthToken]
224220
prefix: str = "Bearer"
225221
auth_header_name: str = "Authorization"
@@ -299,8 +295,6 @@ def get_httpx_client(self) -> httpx.Client:
299295
self._client = httpx.Client(
300296
auth=NeptuneAuthenticator(
301297
credentials=self.credentials,
302-
client_id=self.client_id,
303-
token_refreshing_endpoint=self.token_refreshing_endpoint,
304298
api_key_exchange_factory=self.api_key_exchange_callback,
305299
client=self.get_token_refreshing_client(),
306300
),
@@ -355,14 +349,10 @@ class NeptuneAuthenticator(httpx.Auth):
355349
def __init__(
356350
self,
357351
credentials: Credentials,
358-
client_id: str,
359-
token_refreshing_endpoint: str,
360352
api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken],
361353
client: Client,
362354
):
363355
self._credentials: Credentials = credentials
364-
self._client_id: str = client_id
365-
self._token_refreshing_endpoint: str = token_refreshing_endpoint
366356
self._api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken] = api_key_exchange_factory
367357

368358
self._client = client
@@ -373,11 +363,11 @@ def _refresh_existing_token(self) -> OAuthToken:
373363
raise ValueError("Cannot refresh an empty token")
374364
try:
375365
response = self._client.get_httpx_client().post(
376-
url=self._token_refreshing_endpoint,
366+
url=self._token.token_endpoint,
377367
data={
378368
"grant_type": "refresh_token",
379369
"refresh_token": self._token.refresh_token,
380-
"client_id": self._client_id,
370+
"client_id": self._token.client_id,
381371
"expires_in": self._token.seconds_left,
382372
},
383373
)

src/neptune_api_codegen/docker/rofiles/templates/types.py.jinja

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,27 @@ class OAuthToken:
9191
_expiration_time: float = field(default=0.0, alias="expiration_time", kw_only=True)
9292
access_token: str
9393
refresh_token: str
94+
client_id: str
95+
token_endpoint: str
9496

9597
@classmethod
9698
def from_tokens(cls, access: str, refresh: str) -> "OAuthToken":
9799
# Decode the JWT to get expiration time
98100
try:
99101
decoded_token = jwt.decode(access, options=DECODING_OPTIONS)
100102
expiration_time = float(decoded_token["exp"])
103+
client_id = str(decoded_token["azp"])
104+
issuer = str(decoded_token["iss"]).rstrip("/")
101105
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError) as e:
102106
raise InvalidApiTokenException("Cannot decode the access token") from e
103107

104-
return OAuthToken(access_token=access, refresh_token=refresh, expiration_time=expiration_time)
108+
return OAuthToken(
109+
access_token=access,
110+
refresh_token=refresh,
111+
expiration_time=expiration_time,
112+
client_id=client_id,
113+
token_endpoint=f"{issuer}/protocol/openid-connect/token",
114+
)
105115

106116
@property
107117
def seconds_left(self) -> float:

src/neptune_query/generated/neptune_api/client.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,6 @@ class AuthenticatedClient:
198198
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
199199
argument to the constructor.
200200
credentials: User credentials for authentication.
201-
token_refreshing_endpoint: Token refreshing endpoint url
202-
client_id: Client identifier for the OAuth application.
203201
api_key_exchange_callback: The Neptune API Token exchange function
204202
prefix: The prefix to use for the Authorization header
205203
auth_header_name: The name of the Authorization header
@@ -218,8 +216,6 @@ class AuthenticatedClient:
218216
_async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
219217

220218
credentials: Credentials
221-
token_refreshing_endpoint: str
222-
client_id: str
223219
api_key_exchange_callback: Callable[[Client, Credentials], OAuthToken]
224220
prefix: str = "Bearer"
225221
auth_header_name: str = "Authorization"
@@ -299,8 +295,6 @@ def get_httpx_client(self) -> httpx.Client:
299295
self._client = httpx.Client(
300296
auth=NeptuneAuthenticator(
301297
credentials=self.credentials,
302-
client_id=self.client_id,
303-
token_refreshing_endpoint=self.token_refreshing_endpoint,
304298
api_key_exchange_factory=self.api_key_exchange_callback,
305299
client=self.get_token_refreshing_client(),
306300
),
@@ -355,14 +349,10 @@ class NeptuneAuthenticator(httpx.Auth):
355349
def __init__(
356350
self,
357351
credentials: Credentials,
358-
client_id: str,
359-
token_refreshing_endpoint: str,
360352
api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken],
361353
client: Client,
362354
):
363355
self._credentials: Credentials = credentials
364-
self._client_id: str = client_id
365-
self._token_refreshing_endpoint: str = token_refreshing_endpoint
366356
self._api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken] = api_key_exchange_factory
367357

368358
self._client = client
@@ -373,11 +363,11 @@ def _refresh_existing_token(self) -> OAuthToken:
373363
raise ValueError("Cannot refresh an empty token")
374364
try:
375365
response = self._client.get_httpx_client().post(
376-
url=self._token_refreshing_endpoint,
366+
url=self._token.token_endpoint,
377367
data={
378368
"grant_type": "refresh_token",
379369
"refresh_token": self._token.refresh_token,
380-
"client_id": self._client_id,
370+
"client_id": self._token.client_id,
381371
"expires_in": self._token.seconds_left,
382372
},
383373
)

src/neptune_query/generated/neptune_api/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,27 @@ class OAuthToken:
9090
_expiration_time: float = field(default=0.0, alias="expiration_time", kw_only=True)
9191
access_token: str
9292
refresh_token: str
93+
client_id: str
94+
token_endpoint: str
9395

9496
@classmethod
9597
def from_tokens(cls, access: str, refresh: str) -> "OAuthToken":
9698
# Decode the JWT to get expiration time
9799
try:
98100
decoded_token = jwt.decode(access, options=DECODING_OPTIONS)
99101
expiration_time = float(decoded_token["exp"])
102+
client_id = str(decoded_token["azp"])
103+
issuer = str(decoded_token["iss"]).rstrip("/")
100104
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError) as e:
101105
raise InvalidApiTokenException("Cannot decode the access token") from e
102106

103-
return OAuthToken(access_token=access, refresh_token=refresh, expiration_time=expiration_time)
107+
return OAuthToken(
108+
access_token=access,
109+
refresh_token=refresh,
110+
expiration_time=expiration_time,
111+
client_id=client_id,
112+
token_endpoint=f"{issuer}/protocol/openid-connect/token",
113+
)
104114

105115
@property
106116
def seconds_left(self) -> float:

src/neptune_query/internal/api_utils.py

Lines changed: 1 addition & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from dataclasses import dataclass
17-
from http import HTTPStatus
1816
from typing import (
1917
Callable,
2018
Dict,
@@ -23,91 +21,24 @@
2321

2422
import httpx
2523

26-
from neptune_query.generated.neptune_api import (
27-
AuthenticatedClient,
28-
Client,
29-
)
30-
from neptune_query.generated.neptune_api.api.backend import get_client_config
24+
from neptune_query.generated.neptune_api import AuthenticatedClient
3125
from neptune_query.generated.neptune_api.auth_helpers import exchange_api_key
3226
from neptune_query.generated.neptune_api.credentials import Credentials
33-
from neptune_query.generated.neptune_api.models import ClientConfig
34-
from neptune_query.generated.neptune_api.types import Response
3527

36-
from ..exceptions import NeptuneFailedToFetchClientConfig
3728
from .env import (
3829
NEPTUNE_HTTP_REQUEST_TIMEOUT_SECONDS,
3930
NEPTUNE_VERIFY_SSL,
4031
)
41-
from .retrieval.retry import handle_errors_default
42-
43-
44-
@dataclass
45-
class TokenRefreshingURLs:
46-
authorization_endpoint: str
47-
token_endpoint: str
48-
49-
@classmethod
50-
def from_dict(cls, data: dict) -> "TokenRefreshingURLs":
51-
return TokenRefreshingURLs(
52-
authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"]
53-
)
54-
55-
56-
def _wrap_httpx_json_response(httpx_response: httpx.Response) -> Response:
57-
"""Wrap a httpx.Response into an neptune-api Response object that is compatible
58-
with backoff_retry(). Use .json() as parsed content in the result."""
59-
60-
return Response(
61-
status_code=HTTPStatus(httpx_response.status_code),
62-
content=httpx_response.content,
63-
headers=httpx_response.headers,
64-
parsed=httpx_response.json(),
65-
url=str(httpx_response.url),
66-
)
67-
68-
69-
def get_config_and_token_urls(
70-
*, credentials: Credentials, proxies: Optional[Dict[str, str]] = None
71-
) -> tuple[ClientConfig, TokenRefreshingURLs]:
72-
timeout = httpx.Timeout(NEPTUNE_HTTP_REQUEST_TIMEOUT_SECONDS.get())
73-
with Client(
74-
base_url=credentials.base_url,
75-
httpx_args={"mounts": proxies},
76-
verify_ssl=NEPTUNE_VERIFY_SSL.get(),
77-
timeout=timeout,
78-
headers={"User-Agent": _generate_user_agent()},
79-
) as client:
80-
try:
81-
config_response = handle_errors_default(get_client_config.sync_detailed)(client=client)
82-
config = config_response.parsed
83-
if not isinstance(config, ClientConfig):
84-
raise RuntimeError(f"Expected ClientConfig but got {type(config).__name__}")
85-
86-
urls_response = handle_errors_default(
87-
lambda: _wrap_httpx_json_response(client.get_httpx_client().get(config.security.open_id_discovery))
88-
)()
89-
token_urls_dict = urls_response.parsed
90-
if not isinstance(token_urls_dict, dict):
91-
raise RuntimeError(f"Expected dict but got {type(token_urls_dict).__name__}")
92-
token_urls = TokenRefreshingURLs.from_dict(token_urls_dict)
93-
94-
return config, token_urls
95-
except Exception as e:
96-
raise NeptuneFailedToFetchClientConfig(exception=e) from e
9732

9833

9934
def create_auth_api_client(
10035
*,
10136
credentials: Credentials,
102-
config: ClientConfig,
103-
token_refreshing_urls: TokenRefreshingURLs,
10437
proxies: Optional[Dict[str, str]] = None,
10538
) -> AuthenticatedClient:
10639
return AuthenticatedClient(
10740
base_url=credentials.base_url,
10841
credentials=credentials,
109-
client_id=config.security.client_id,
110-
token_refreshing_endpoint=token_refreshing_urls.token_endpoint,
11142
api_key_exchange_callback=exchange_api_key,
11243
verify_ssl=NEPTUNE_VERIFY_SSL.get(),
11344
httpx_args={"mounts": proxies, "http2": False},

src/neptune_query/internal/client.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@
2828
from neptune_query.generated.neptune_api import AuthenticatedClient
2929
from neptune_query.generated.neptune_api.credentials import Credentials
3030

31-
from .api_utils import (
32-
create_auth_api_client,
33-
get_config_and_token_urls,
34-
)
31+
from .api_utils import create_auth_api_client
3532
from .context import Context
3633

3734
# Disable httpx logging, httpx logs requests at INFO level
@@ -58,13 +55,7 @@ def get_client(context: Context, proxies: Optional[Dict[str, str]] = None) -> Au
5855
raise ValueError("API token is not set")
5956

6057
credentials = Credentials.from_api_key(api_key=context.api_token)
61-
config, token_urls = get_config_and_token_urls(credentials=credentials, proxies=proxies)
62-
client = create_auth_api_client(
63-
credentials=credentials,
64-
config=config,
65-
token_refreshing_urls=token_urls,
66-
proxies=proxies,
67-
)
58+
client = create_auth_api_client(credentials=credentials, proxies=proxies)
6859

6960
_cache[hash_key] = client
7061
return client

tests/e2e/conftest.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616

1717
from neptune_query.generated.neptune_api import AuthenticatedClient
1818
from neptune_query.generated.neptune_api.credentials import Credentials
19-
from neptune_query.internal.api_utils import (
20-
create_auth_api_client,
21-
get_config_and_token_urls,
22-
)
19+
from neptune_query.internal.api_utils import create_auth_api_client
2320
from neptune_query.internal.composition import concurrency
2421
from neptune_query.internal.context import set_api_token
2522
from tests.e2e.data_ingestion import (
@@ -93,10 +90,7 @@ def set_api_token_auto(api_token) -> None:
9390
@pytest.fixture(scope="session")
9491
def client(api_token) -> AuthenticatedClient:
9592
credentials = Credentials.from_api_key(api_key=api_token)
96-
config, token_urls = get_config_and_token_urls(credentials=credentials, proxies=None)
97-
client = create_auth_api_client(
98-
credentials=credentials, config=config, token_refreshing_urls=token_urls, proxies=None
99-
)
93+
client = create_auth_api_client(credentials=credentials, proxies=None)
10094

10195
return client
10296

tests/performance_e2e/conftest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,17 @@ def http_client(monkeypatch, backend_base_url: str, api_token: str) -> ClientPro
210210
Returns:
211211
A wrapper for the Neptune API client with header management
212212
"""
213-
never_expiring_token = OAuthToken(access_token="x", refresh_token="x", expiration_time=time.time() + 10_000_000)
213+
realm_base_url = f"{backend_base_url.rstrip('/')}/auth/realms/neptune"
214+
never_expiring_token = OAuthToken(
215+
access_token="x",
216+
refresh_token="x",
217+
expiration_time=time.time() + 10_000_000,
218+
client_id="perf-test-client-id",
219+
token_endpoint=f"{realm_base_url}/protocol/openid-connect/token",
220+
)
214221
patched_client = AuthenticatedClient(
215222
base_url=backend_base_url,
216223
credentials=Credentials.from_api_key(api_token),
217-
client_id="",
218-
token_refreshing_endpoint="",
219224
api_key_exchange_callback=lambda _client, _credentials: never_expiring_token,
220225
verify_ssl=False,
221226
httpx_args={"http2": False},

tests/unit/internal/test_api_client_cache.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,7 @@ def clear_cache_before_test():
6969

7070
@fixture(autouse=True)
7171
def mock_networking():
72-
with (
73-
patch("neptune_query.internal.client.get_config_and_token_urls") as get_config_and_token_urls,
74-
patch("neptune_query.internal.client.create_auth_api_client") as create_auth_api_client,
75-
):
76-
get_config_and_token_urls.return_value = (Mock(), Mock())
72+
with patch("neptune_query.internal.client.create_auth_api_client") as create_auth_api_client:
7773
# create_auth_api_client() needs to return a different "client" each time
7874
create_auth_api_client.side_effect = lambda *args, **kwargs: Mock()
7975
yield

0 commit comments

Comments
 (0)