diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md index 40cfc0b8c9..0bafbef1fc 100644 --- a/mkdocs/docs/configuration.md +++ b/mkdocs/docs/configuration.md @@ -388,6 +388,8 @@ The RESTCatalog supports pluggable authentication via the `auth` configuration b - `noop`: No authentication (no Authorization header sent). - `basic`: HTTP Basic authentication. +- `oauth2`: OAuth2 client credentials flow. +- `legacyoauth2`: Legacy OAuth2 client credentials flow (deprecated; will be removed in PyIceberg 1.0.0). - `custom`: Custom authentication manager (requires `auth.impl`). - `google`: Google Authentication support @@ -411,9 +413,10 @@ catalog: | Property | Required | Description | |------------------|----------|-------------------------------------------------------------------------------------------------| -| `auth.type` | Yes | The authentication type to use (`noop`, `basic`, or `custom`). | +| `auth.type` | Yes | The authentication type to use (`noop`, `basic`, `oauth2`, `legacyoauth2`, `google`, or `custom`). | | `auth.impl` | Conditionally | The fully qualified class path for a custom AuthManager. Required if `auth.type` is `custom`. | | `auth.basic` | If type is `basic` | Block containing `username` and `password` for HTTP Basic authentication. | +| `auth.oauth2` | If type is `oauth2` | Block containing OAuth2 configuration (see below). | | `auth.custom` | If type is `custom` | Block containing configuration for the custom AuthManager. | | `auth.google` | If type is `google` | Block containing `credentials_path` to a service account file (if using). Will default to using Application Default Credentials. 
class OAuth2TokenProvider:
    """Thread-safe OAuth2 token provider with token refresh support.

    The provider requests an access token from ``token_url`` using the OAuth2
    client credentials grant, authenticating with an HTTP Basic header built
    from ``client_id`` and ``client_secret``. The token is cached and a new one
    is requested once the cached token is within ``refresh_margin`` seconds of
    its expiry.

    Args:
        client_id: OAuth2 client identifier.
        client_secret: OAuth2 client secret.
        token_url: URL of the token endpoint the request is POSTed to.
        scope: Optional scope sent with the token request.
        refresh_margin: Seconds before expiry at which the token is refreshed.
        expires_in: Fallback token lifetime in seconds, used only when the
            token response does not contain ``expires_in``.
    """

    client_id: str
    client_secret: str
    token_url: str
    scope: Optional[str]
    refresh_margin: int
    expires_in: Optional[int]

    _token: Optional[str]
    # Deadline on the ``time.monotonic()`` clock, hence a float rather than an int.
    _expires_at: float
    _lock: threading.Lock

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: Optional[str] = None,
        refresh_margin: int = 60,
        expires_in: Optional[int] = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.refresh_margin = refresh_margin
        self.expires_in = expires_in

        # No token yet: the first ``get_token`` call always triggers a refresh.
        self._token = None
        self._expires_at = 0.0
        self._lock = threading.Lock()

    @cached_property
    def _client_secret_header(self) -> str:
        """Return the ``Basic`` Authorization header value for the client credentials."""
        creds = f"{self.client_id}:{self.client_secret}"
        creds_bytes = creds.encode("utf-8")
        b64_creds = base64.b64encode(creds_bytes).decode("utf-8")
        return f"Basic {b64_creds}"

    def _token_request_data(self) -> Dict[str, str]:
        """Build the form body of the access token request.

        ``grant_type`` is REQUIRED and must be ``client_credentials`` for this
        flow (RFC 6749, section 4.4.2); ``scope`` is only sent when configured.
        """
        data = {"grant_type": "client_credentials"}
        if self.scope:
            data["scope"] = self.scope
        return data

    def _refresh_token(self) -> None:
        """Request a new access token and record when it must be refreshed.

        Raises:
            requests.HTTPError: If the token endpoint answers with an error status.
            ValueError: If neither the response nor the provider configuration
                supplies a token lifetime.
        """
        response = requests.post(
            self.token_url,
            data=self._token_request_data(),
            headers={"Authorization": self._client_secret_header},
        )
        response.raise_for_status()
        result = response.json()

        self._token = result["access_token"]
        # Prefer the lifetime reported by the server; fall back to the configured one.
        expires_in = result.get("expires_in", self.expires_in)
        if expires_in is None:
            raise ValueError(
                "The expiration time of the Token must be provided by the Server in the Access Token Response in `expires_in` field, or by the PyIceberg Client."
            )
        # Subtract the margin so the token is renewed before it actually expires.
        self._expires_at = time.monotonic() + expires_in - self.refresh_margin

    def get_token(self) -> str:
        """Return a cached access token, refreshing it first when missing or due.

        The check and the refresh run under a lock so concurrent callers do
        not issue overlapping token requests.

        Raises:
            ValueError: If no token is available after a refresh.
        """
        with self._lock:
            if not self._token or time.monotonic() >= self._expires_at:
                self._refresh_token()
            if self._token is None:
                raise ValueError("Authorization token is None after refresh")
            return self._token
"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk", + "scope": "read", + }, + status_code=200, + ) + requests_mock.get( + f"{TEST_URI}v1/config", + json={"defaults": {}, "overrides": {}}, + status_code=200, + ) + # Given + catalog_properties = { + "uri": TEST_URI, + "auth": { + "type": "oauth2", + "oauth2": { + "client_id": "some_client_id", + "client_secret": "some_client_secret", + "token_url": f"{TEST_URI}oauth2/token", + "scope": "read", + }, + }, + } + catalog = RestCatalog("rest", **catalog_properties) # type: ignore + assert catalog.uri == TEST_URI + + +def test_rest_catalog_oauth2_non_200_token_response(requests_mock: Mocker) -> None: + requests_mock.post( + f"{TEST_URI}oauth2/token", + json={"error": "invalid_client"}, + status_code=401, + ) + catalog_properties = { + "uri": TEST_URI, + "auth": { + "type": "oauth2", + "oauth2": { + "client_id": "bad_client_id", + "client_secret": "bad_client_secret", + "token_url": f"{TEST_URI}oauth2/token", + "scope": "read", + }, + }, + } + + with pytest.raises(HTTPError): + RestCatalog("rest", **catalog_properties) # type: ignore + + EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}