Skip to content

New OAuth2Manager #2244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
21 changes: 19 additions & 2 deletions mkdocs/docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ The RESTCatalog supports pluggable authentication via the `auth` configuration b

- `noop`: No authentication (no Authorization header sent).
- `basic`: HTTP Basic authentication.
- `oauth2`: OAuth2 client credentials flow.
- `legacyoauth2`: Legacy OAuth2 client credentials flow (Deprecated and will be removed in PyIceberg 1.0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think this becomes a little confusing. `legacyoauth2` is a fallback mechanism, i.e. it is what gets used when the `auth:` block is not provided. I think we should call this out.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense - I'm fine with leaving legacyoauth2 completely out of this section and relying on the existing Authentication Options to explain those configurations

- `custom`: Custom authentication manager (requires `auth.impl`).
- `google`: Google Authentication support

Expand All @@ -411,9 +413,10 @@ catalog:

| Property | Required | Description |
|------------------|----------|-------------------------------------------------------------------------------------------------|
| `auth.type` | Yes | The authentication type to use (`noop`, `basic`, or `custom`). |
| `auth.type` | Yes | The authentication type to use (`noop`, `basic`, `oauth2`, or `custom`). |
| `auth.impl` | Conditionally | The fully qualified class path for a custom AuthManager. Required if `auth.type` is `custom`. |
| `auth.basic` | If type is `basic` | Block containing `username` and `password` for HTTP Basic authentication. |
| `auth.oauth2` | If type is `oauth2` | Block containing OAuth2 configuration (see below). |
| `auth.custom` | If type is `custom` | Block containing configuration for the custom AuthManager. |
| `auth.google` | If type is `google` | Block containing `credentials_path` to a service account file (if using). Will default to using Application Default Credentials. |

Expand All @@ -436,6 +439,20 @@ auth:
password: mypass
```

OAuth2 Authentication:

```yaml
auth:
type: oauth2
oauth2:
client_id: my-client-id
client_secret: my-client-secret
token_url: https://auth.example.com/oauth/token
scope: read
refresh_margin: 60 # (optional) seconds before expiry to refresh
expires_in: 3600 # (optional) fallback if server does not provide
```

Custom Authentication:

```yaml
Expand All @@ -451,7 +468,7 @@ auth:

- If `auth.type` is `custom`, you **must** specify `auth.impl` with the full class path to your custom AuthManager.
- If `auth.type` is not `custom`, specifying `auth.impl` is not allowed.
- The configuration block under each type (e.g., `basic`, `custom`) is passed as keyword arguments to the corresponding AuthManager.
- The configuration block under each type (e.g., `basic`, `oauth2`, `custom`) is passed as keyword arguments to the corresponding AuthManager.

<!-- markdown-link-check-enable-->

Expand Down
97 changes: 97 additions & 0 deletions pyiceberg/catalog/rest/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
import base64
import importlib
import logging
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, Dict, List, Optional, Type

import requests
from requests import HTTPError, PreparedRequest, Session
from requests.auth import AuthBase

Expand Down Expand Up @@ -121,6 +125,98 @@ def auth_header(self) -> str:
return f"Bearer {self._token}"


class OAuth2TokenProvider:
    """Thread-safe OAuth2 token provider with token refresh support.

    Requests an access token from ``token_url`` using the client-credentials
    grant, authenticating with an HTTP Basic ``Authorization`` header built
    from ``client_id`` and ``client_secret``. The token is cached and a new
    one is requested once the cached token is within ``refresh_margin``
    seconds of its expiry.

    Args:
        client_id: OAuth2 client identifier.
        client_secret: OAuth2 client secret.
        token_url: URL of the token endpoint the request is POSTed to.
        scope: Optional scope sent with the token request; omitted when falsy.
        refresh_margin: Seconds before expiry at which the token is refreshed.
        expires_in: Fallback token lifetime in seconds, used only when the
            token response carries no ``expires_in`` field.
    """

    client_id: str
    client_secret: str
    token_url: str
    scope: Optional[str]
    refresh_margin: int
    expires_in: Optional[int]

    _token: Optional[str]
    # Deadline on the `time.monotonic()` clock, which yields floats.
    _expires_at: float
    _lock: threading.Lock

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: Optional[str] = None,
        refresh_margin: int = 60,
        expires_in: Optional[int] = None,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.refresh_margin = refresh_margin
        self.expires_in = expires_in

        # No token yet: a deadline of 0 together with the empty token forces
        # the first `get_token()` call to fetch one.
        self._token = None
        self._expires_at = 0
        self._lock = threading.Lock()

    @cached_property
    def _client_secret_header(self) -> str:
        """Return the HTTP Basic ``Authorization`` value for the client credentials."""
        creds = f"{self.client_id}:{self.client_secret}"
        creds_bytes = creds.encode("utf-8")
        b64_creds = base64.b64encode(creds_bytes).decode("utf-8")
        return f"Basic {b64_creds}"

    def _token_request_data(self) -> Dict[str, str]:
        """Build the form fields sent in the body of the token request."""
        # RFC 6749 section 4.4.2 makes `grant_type` a REQUIRED parameter whose
        # value must be "client_credentials"; without it the request is invalid.
        data = {"grant_type": "client_credentials"}
        if self.scope:
            data["scope"] = self.scope
        return data

    def _refresh_token(self) -> None:
        """Fetch a new access token and record when it has to be refreshed.

        Callers are expected to hold ``self._lock``.

        Raises:
            requests.HTTPError: If the token endpoint answers with an error status.
            KeyError: If the response has no ``access_token`` field.
            ValueError: If neither the response nor the client configuration
                provides a token lifetime.
        """
        # NOTE(review): no timeout is passed, so this call can block for as
        # long as the server keeps the connection open — consider adding one.
        response = requests.post(
            self.token_url,
            data=self._token_request_data(),
            headers={"Authorization": self._client_secret_header},
        )
        response.raise_for_status()
        result = response.json()

        self._token = result["access_token"]
        # Prefer the lifetime reported by the server; fall back to the
        # client-configured value only when the field is absent.
        expires_in = result.get("expires_in", self.expires_in)
        if expires_in is None:
            raise ValueError(
                "The expiration time of the Token must be provided by the Server in the Access Token Response in `expires_in` field, or by the PyIceberg Client."
            )
        # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
        self._expires_at = time.monotonic() + expires_in - self.refresh_margin

    def get_token(self) -> str:
        """Return a valid access token, refreshing it first when needed.

        The lock is held for the whole check-and-refresh so that concurrent
        callers do not issue overlapping token requests.

        Raises:
            ValueError: If no token is available after a refresh.
        """
        with self._lock:
            if not self._token or time.monotonic() >= self._expires_at:
                self._refresh_token()
            if self._token is None:
                raise ValueError("Authorization token is None after refresh")
            return self._token


class OAuth2AuthManager(AuthManager):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any tests for `LegacyOAuth2AuthManager`? Do we want `OAuth2AuthManager` to have feature parity with it in this first release?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see `credential`, `resource`, and `audience`:

| credential | client_id:client_secret | Credential to use for OAuth2 credential flow when initializing the catalog |
| scope | openid offline corpds:ds:profile | Desired scope of the requested security token (default : catalog) |
| resource | rest_catalog.iceberg.com | URI for the target resource or service |
| audience | rest_catalog | Logical name of target resource or service |

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using `client_id` and `client_secret` in the current implementation, as opposed to `credential`. This is also currently in draft mode, and I intend to review the OAuth2 spec and other industry-standard implementations in more depth before finalizing the implementation.

"""Auth Manager implementation that supports OAuth2 as defined in IETF RFC6749."""

def __init__(
    self,
    client_id: str,
    client_secret: str,
    token_url: str,
    scope: Optional[str] = None,
    refresh_margin: int = 60,
    expires_in: Optional[int] = None,
):
    """Create the token provider that supplies bearer tokens for this manager.

    Every argument is forwarded unchanged to ``OAuth2TokenProvider``.
    """
    provider = OAuth2TokenProvider(
        client_id=client_id,
        client_secret=client_secret,
        token_url=token_url,
        scope=scope,
        refresh_margin=refresh_margin,
        expires_in=expires_in,
    )
    self.token_provider = provider

def auth_header(self) -> str:
    """Return the ``Authorization`` header value built from the provider's current token."""
    access_token = self.token_provider.get_token()
    return f"Bearer {access_token}"


class GoogleAuthManager(AuthManager):
"""An auth manager that is responsible for handling Google credentials."""

Expand Down Expand Up @@ -228,4 +324,5 @@ def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager:
# Register the built-in auth managers under the names used for `auth.type`.
# NOTE(review): these calls appear to sit at module level, so they would run
# when the module is first imported — confirm.
AuthManagerFactory.register("noop", NoopAuthManager)
AuthManagerFactory.register("basic", BasicAuthManager)
AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager)
AuthManagerFactory.register("oauth2", OAuth2AuthManager)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is it a good idea to call `AuthManagerFactory.register` directly in the file? Would it be better if it were encapsulated in a function?

I'm worried about an import automatically running this code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that with a Python import the module's code is executed only once, and subsequent imports make Python retrieve the already-loaded module object from `sys.modules`, unless the application owner intentionally reloads the module. The registration is also currently idempotent.

AuthManagerFactory.register("google", GoogleAuthManager)
58 changes: 58 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from unittest import mock

import pytest
from requests.exceptions import HTTPError
from requests_mock import Mocker

import pyiceberg
Expand Down Expand Up @@ -1646,6 +1647,63 @@ def test_rest_catalog_with_unsupported_auth_type() -> None:
assert "Could not load AuthManager class for 'unsupported'" in str(e.value)


def test_rest_catalog_with_oauth2_auth_type(requests_mock: Mocker) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a test where fetching the token does not return a 200 or where refreshing the token does not return a 200.

That would allow us to verify that the non-happy paths work as intended.

requests_mock.post(
f"{TEST_URI}oauth2/token",
json={
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk",
"scope": "read",
},
status_code=200,
)
requests_mock.get(
f"{TEST_URI}v1/config",
json={"defaults": {}, "overrides": {}},
status_code=200,
)
# Given
catalog_properties = {
"uri": TEST_URI,
"auth": {
"type": "oauth2",
"oauth2": {
"client_id": "some_client_id",
"client_secret": "some_client_secret",
"token_url": f"{TEST_URI}oauth2/token",
"scope": "read",
},
},
}
catalog = RestCatalog("rest", **catalog_properties) # type: ignore
assert catalog.uri == TEST_URI


def test_rest_catalog_oauth2_non_200_token_response(requests_mock: Mocker) -> None:
    """A 401 from the token endpoint surfaces as an ``HTTPError`` when the catalog is built."""
    token_endpoint = f"{TEST_URI}oauth2/token"

    # Given: a token endpoint that rejects the client credentials.
    requests_mock.post(token_endpoint, json={"error": "invalid_client"}, status_code=401)

    oauth2_settings = {
        "client_id": "bad_client_id",
        "client_secret": "bad_client_secret",
        "token_url": token_endpoint,
        "scope": "read",
    }
    catalog_properties = {
        "uri": TEST_URI,
        "auth": {"type": "oauth2", "oauth2": oauth2_settings},
    }

    # Then: constructing the catalog propagates the HTTP failure.
    with pytest.raises(HTTPError):
        RestCatalog("rest", **catalog_properties)  # type: ignore


EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}


Expand Down