Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions src/neptune_api_codegen/docker/rofiles/neptune_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ class AuthenticatedClient:
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
argument to the constructor.
credentials: User credentials for authentication.
token_refreshing_endpoint: Token refreshing endpoint url
client_id: Client identifier for the OAuth application.
api_key_exchange_callback: The Neptune API Token exchange function
prefix: The prefix to use for the Authorization header
auth_header_name: The name of the Authorization header
Expand All @@ -218,8 +216,6 @@ class AuthenticatedClient:
_async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)

credentials: Credentials
token_refreshing_endpoint: str
client_id: str
api_key_exchange_callback: Callable[[Client, Credentials], OAuthToken]
prefix: str = "Bearer"
auth_header_name: str = "Authorization"
Expand Down Expand Up @@ -299,8 +295,6 @@ def get_httpx_client(self) -> httpx.Client:
self._client = httpx.Client(
auth=NeptuneAuthenticator(
credentials=self.credentials,
client_id=self.client_id,
token_refreshing_endpoint=self.token_refreshing_endpoint,
api_key_exchange_factory=self.api_key_exchange_callback,
client=self.get_token_refreshing_client(),
),
Expand Down Expand Up @@ -355,14 +349,10 @@ class NeptuneAuthenticator(httpx.Auth):
def __init__(
self,
credentials: Credentials,
client_id: str,
token_refreshing_endpoint: str,
api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken],
client: Client,
):
self._credentials: Credentials = credentials
self._client_id: str = client_id
self._token_refreshing_endpoint: str = token_refreshing_endpoint
self._api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken] = api_key_exchange_factory

self._client = client
Expand All @@ -373,11 +363,11 @@ def _refresh_existing_token(self) -> OAuthToken:
raise ValueError("Cannot refresh an empty token")
try:
response = self._client.get_httpx_client().post(
url=self._token_refreshing_endpoint,
url=self._token.token_endpoint,
data={
"grant_type": "refresh_token",
"refresh_token": self._token.refresh_token,
"client_id": self._client_id,
"client_id": self._token.client_id,
"expires_in": self._token.seconds_left,
},
)
Expand Down
12 changes: 11 additions & 1 deletion src/neptune_api_codegen/docker/rofiles/templates/types.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,27 @@ class OAuthToken:
_expiration_time: float = field(default=0.0, alias="expiration_time", kw_only=True)
access_token: str
refresh_token: str
client_id: str
token_endpoint: str

@classmethod
def from_tokens(cls, access: str, refresh: str) -> "OAuthToken":
# Decode the JWT to get expiration time
try:
decoded_token = jwt.decode(access, options=DECODING_OPTIONS)
expiration_time = float(decoded_token["exp"])
client_id = str(decoded_token["azp"])
issuer = str(decoded_token["iss"]).rstrip("/")
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError) as e:
raise InvalidApiTokenException("Cannot decode the access token") from e

return OAuthToken(access_token=access, refresh_token=refresh, expiration_time=expiration_time)
return OAuthToken(
access_token=access,
refresh_token=refresh,
expiration_time=expiration_time,
client_id=client_id,
token_endpoint=f"{issuer}/protocol/openid-connect/token",
)

@property
def seconds_left(self) -> float:
Expand Down
14 changes: 2 additions & 12 deletions src/neptune_query/generated/neptune_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ class AuthenticatedClient:
status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
argument to the constructor.
credentials: User credentials for authentication.
token_refreshing_endpoint: Token refreshing endpoint url
client_id: Client identifier for the OAuth application.
api_key_exchange_callback: The Neptune API Token exchange function
prefix: The prefix to use for the Authorization header
auth_header_name: The name of the Authorization header
Expand All @@ -218,8 +216,6 @@ class AuthenticatedClient:
_async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)

credentials: Credentials
token_refreshing_endpoint: str
client_id: str
api_key_exchange_callback: Callable[[Client, Credentials], OAuthToken]
prefix: str = "Bearer"
auth_header_name: str = "Authorization"
Expand Down Expand Up @@ -299,8 +295,6 @@ def get_httpx_client(self) -> httpx.Client:
self._client = httpx.Client(
auth=NeptuneAuthenticator(
credentials=self.credentials,
client_id=self.client_id,
token_refreshing_endpoint=self.token_refreshing_endpoint,
api_key_exchange_factory=self.api_key_exchange_callback,
client=self.get_token_refreshing_client(),
),
Expand Down Expand Up @@ -355,14 +349,10 @@ class NeptuneAuthenticator(httpx.Auth):
def __init__(
self,
credentials: Credentials,
client_id: str,
token_refreshing_endpoint: str,
api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken],
client: Client,
):
self._credentials: Credentials = credentials
self._client_id: str = client_id
self._token_refreshing_endpoint: str = token_refreshing_endpoint
self._api_key_exchange_factory: Callable[[Client, Credentials], OAuthToken] = api_key_exchange_factory

self._client = client
Expand All @@ -373,11 +363,11 @@ def _refresh_existing_token(self) -> OAuthToken:
raise ValueError("Cannot refresh an empty token")
try:
response = self._client.get_httpx_client().post(
url=self._token_refreshing_endpoint,
url=self._token.token_endpoint,
data={
"grant_type": "refresh_token",
"refresh_token": self._token.refresh_token,
"client_id": self._client_id,
"client_id": self._token.client_id,
"expires_in": self._token.seconds_left,
},
)
Expand Down
12 changes: 11 additions & 1 deletion src/neptune_query/generated/neptune_api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,27 @@ class OAuthToken:
_expiration_time: float = field(default=0.0, alias="expiration_time", kw_only=True)
access_token: str
refresh_token: str
client_id: str
token_endpoint: str

@classmethod
def from_tokens(cls, access: str, refresh: str) -> "OAuthToken":
# Decode the JWT to get expiration time
try:
decoded_token = jwt.decode(access, options=DECODING_OPTIONS)
expiration_time = float(decoded_token["exp"])
client_id = str(decoded_token["azp"])
issuer = str(decoded_token["iss"]).rstrip("/")
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError, KeyError) as e:
raise InvalidApiTokenException("Cannot decode the access token") from e

return OAuthToken(access_token=access, refresh_token=refresh, expiration_time=expiration_time)
return OAuthToken(
access_token=access,
refresh_token=refresh,
expiration_time=expiration_time,
client_id=client_id,
token_endpoint=f"{issuer}/protocol/openid-connect/token",
)

@property
def seconds_left(self) -> float:
Expand Down
71 changes: 1 addition & 70 deletions src/neptune_query/internal/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from http import HTTPStatus
from typing import (
Callable,
Dict,
Expand All @@ -23,91 +21,24 @@

import httpx

from neptune_query.generated.neptune_api import (
AuthenticatedClient,
Client,
)
from neptune_query.generated.neptune_api.api.backend import get_client_config
from neptune_query.generated.neptune_api import AuthenticatedClient
from neptune_query.generated.neptune_api.auth_helpers import exchange_api_key
from neptune_query.generated.neptune_api.credentials import Credentials
from neptune_query.generated.neptune_api.models import ClientConfig
from neptune_query.generated.neptune_api.types import Response

from ..exceptions import NeptuneFailedToFetchClientConfig
from .env import (
NEPTUNE_HTTP_REQUEST_TIMEOUT_SECONDS,
NEPTUNE_VERIFY_SSL,
)
from .retrieval.retry import handle_errors_default


@dataclass
class TokenRefreshingURLs:
authorization_endpoint: str
token_endpoint: str

@classmethod
def from_dict(cls, data: dict) -> "TokenRefreshingURLs":
return TokenRefreshingURLs(
authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"]
)


def _wrap_httpx_json_response(httpx_response: httpx.Response) -> Response:
"""Wrap a httpx.Response into an neptune-api Response object that is compatible
with backoff_retry(). Use .json() as parsed content in the result."""

return Response(
status_code=HTTPStatus(httpx_response.status_code),
content=httpx_response.content,
headers=httpx_response.headers,
parsed=httpx_response.json(),
url=str(httpx_response.url),
)


def get_config_and_token_urls(
*, credentials: Credentials, proxies: Optional[Dict[str, str]] = None
) -> tuple[ClientConfig, TokenRefreshingURLs]:
timeout = httpx.Timeout(NEPTUNE_HTTP_REQUEST_TIMEOUT_SECONDS.get())
with Client(
base_url=credentials.base_url,
httpx_args={"mounts": proxies},
verify_ssl=NEPTUNE_VERIFY_SSL.get(),
timeout=timeout,
headers={"User-Agent": _generate_user_agent()},
) as client:
try:
config_response = handle_errors_default(get_client_config.sync_detailed)(client=client)
config = config_response.parsed
if not isinstance(config, ClientConfig):
raise RuntimeError(f"Expected ClientConfig but got {type(config).__name__}")

urls_response = handle_errors_default(
lambda: _wrap_httpx_json_response(client.get_httpx_client().get(config.security.open_id_discovery))
)()
token_urls_dict = urls_response.parsed
if not isinstance(token_urls_dict, dict):
raise RuntimeError(f"Expected dict but got {type(token_urls_dict).__name__}")
token_urls = TokenRefreshingURLs.from_dict(token_urls_dict)

return config, token_urls
except Exception as e:
raise NeptuneFailedToFetchClientConfig(exception=e) from e


def create_auth_api_client(
*,
credentials: Credentials,
config: ClientConfig,
token_refreshing_urls: TokenRefreshingURLs,
proxies: Optional[Dict[str, str]] = None,
) -> AuthenticatedClient:
return AuthenticatedClient(
base_url=credentials.base_url,
credentials=credentials,
client_id=config.security.client_id,
token_refreshing_endpoint=token_refreshing_urls.token_endpoint,
api_key_exchange_callback=exchange_api_key,
verify_ssl=NEPTUNE_VERIFY_SSL.get(),
httpx_args={"mounts": proxies, "http2": False},
Expand Down
13 changes: 2 additions & 11 deletions src/neptune_query/internal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@
from neptune_query.generated.neptune_api import AuthenticatedClient
from neptune_query.generated.neptune_api.credentials import Credentials

from .api_utils import (
create_auth_api_client,
get_config_and_token_urls,
)
from .api_utils import create_auth_api_client
from .context import Context

# Disable httpx logging, httpx logs requests at INFO level
Expand All @@ -58,13 +55,7 @@ def get_client(context: Context, proxies: Optional[Dict[str, str]] = None) -> Au
raise ValueError("API token is not set")

credentials = Credentials.from_api_key(api_key=context.api_token)
config, token_urls = get_config_and_token_urls(credentials=credentials, proxies=proxies)
client = create_auth_api_client(
credentials=credentials,
config=config,
token_refreshing_urls=token_urls,
proxies=proxies,
)
client = create_auth_api_client(credentials=credentials, proxies=proxies)

_cache[hash_key] = client
return client
Expand Down
10 changes: 2 additions & 8 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@

from neptune_query.generated.neptune_api import AuthenticatedClient
from neptune_query.generated.neptune_api.credentials import Credentials
from neptune_query.internal.api_utils import (
create_auth_api_client,
get_config_and_token_urls,
)
from neptune_query.internal.api_utils import create_auth_api_client
from neptune_query.internal.composition import concurrency
from neptune_query.internal.context import set_api_token
from tests.e2e.data_ingestion import (
Expand Down Expand Up @@ -93,10 +90,7 @@ def set_api_token_auto(api_token) -> None:
@pytest.fixture(scope="session")
def client(api_token) -> AuthenticatedClient:
credentials = Credentials.from_api_key(api_key=api_token)
config, token_urls = get_config_and_token_urls(credentials=credentials, proxies=None)
client = create_auth_api_client(
credentials=credentials, config=config, token_refreshing_urls=token_urls, proxies=None
)
client = create_auth_api_client(credentials=credentials, proxies=None)

return client

Expand Down
11 changes: 8 additions & 3 deletions tests/performance_e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,17 @@ def http_client(monkeypatch, backend_base_url: str, api_token: str) -> ClientPro
Returns:
A wrapper for the Neptune API client with header management
"""
never_expiring_token = OAuthToken(access_token="x", refresh_token="x", expiration_time=time.time() + 10_000_000)
realm_base_url = f"{backend_base_url.rstrip('/')}/auth/realms/neptune"
never_expiring_token = OAuthToken(
access_token="x",
refresh_token="x",
expiration_time=time.time() + 10_000_000,
client_id="perf-test-client-id",
token_endpoint=f"{realm_base_url}/protocol/openid-connect/token",
)
patched_client = AuthenticatedClient(
base_url=backend_base_url,
credentials=Credentials.from_api_key(api_token),
client_id="",
token_refreshing_endpoint="",
api_key_exchange_callback=lambda _client, _credentials: never_expiring_token,
verify_ssl=False,
httpx_args={"http2": False},
Expand Down
6 changes: 1 addition & 5 deletions tests/unit/internal/test_api_client_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,7 @@ def clear_cache_before_test():

@fixture(autouse=True)
def mock_networking():
with (
patch("neptune_query.internal.client.get_config_and_token_urls") as get_config_and_token_urls,
patch("neptune_query.internal.client.create_auth_api_client") as create_auth_api_client,
):
get_config_and_token_urls.return_value = (Mock(), Mock())
with patch("neptune_query.internal.client.create_auth_api_client") as create_auth_api_client:
# create_auth_api_client() needs to return a different "client" each time
create_auth_api_client.side_effect = lambda *args, **kwargs: Mock()
yield
Expand Down
Loading
Loading