Skip to content

Commit bd5b8ba

Browse files
sungwykevinjqliuCopilot
authored
New OAuth2AuthManager (#2244)
<!-- Thanks for opening a pull request! -->
<!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual Github issue id. -->
<!-- Closes #${GITHUB_ISSUE_ID} -->

# Rationale for this change

New `OAuth2AuthManager` implementation that builds on `AuthManager` and more closely follows the recommendations of RFC 6749 (https://datatracker.ietf.org/doc/html/rfc6749). It sends the encoded client credentials to the authorization server in a `Basic` `Authorization` header, as the RFC recommends, instead of placing them in the request body, which is less secure. It also proactively refreshes the access token by checking its expiration time.

# Are these changes tested?

Yes, in both unit and integration tests.

# Are there any user-facing changes?

No, this is a new feature.

<!-- In the case of user-facing changes, please add the changelog label. -->

---------

Co-authored-by: Kevin Liu <[email protected]>
Co-authored-by: Kevin Liu <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent 4234879 commit bd5b8ba

File tree

3 files changed

+173
-2
lines changed

3 files changed

+173
-2
lines changed

mkdocs/docs/configuration.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ The RESTCatalog supports pluggable authentication via the `auth` configuration b
388388

389389
- `noop`: No authentication (no Authorization header sent).
390390
- `basic`: HTTP Basic authentication.
391+
- `oauth2`: OAuth2 client credentials flow.
391392
- `custom`: Custom authentication manager (requires `auth.impl`).
392393
- `google`: Google Authentication support
393394

@@ -411,9 +412,10 @@ catalog:
411412

412413
| Property | Required | Description |
413414
|------------------|----------|-------------------------------------------------------------------------------------------------|
414-
| `auth.type` | Yes | The authentication type to use (`noop`, `basic`, or `custom`). |
415+
| `auth.type` | Yes | The authentication type to use (`noop`, `basic`, `oauth2`, `custom`, or `google`). |
415416
| `auth.impl` | Conditionally | The fully qualified class path for a custom AuthManager. Required if `auth.type` is `custom`. |
416417
| `auth.basic` | If type is `basic` | Block containing `username` and `password` for HTTP Basic authentication. |
418+
| `auth.oauth2` | If type is `oauth2` | Block containing OAuth2 configuration (see below). |
417419
| `auth.custom` | If type is `custom` | Block containing configuration for the custom AuthManager. |
418420
| `auth.google` | If type is `google` | Block containing `credentials_path` to a service account file (if using). Will default to using Application Default Credentials. |
419421

@@ -436,6 +438,20 @@ auth:
436438
password: mypass
437439
```
438440

441+
OAuth2 Authentication:
442+
443+
```yaml
444+
auth:
445+
type: oauth2
446+
oauth2:
447+
client_id: my-client-id
448+
client_secret: my-client-secret
449+
token_url: https://auth.example.com/oauth/token
450+
scope: read
451+
refresh_margin: 60 # (optional) seconds before expiry to refresh
452+
expires_in: 3600 # (optional) fallback token lifetime in seconds, used if the server response does not include `expires_in`
453+
```
454+
439455
Custom Authentication:
440456

441457
```yaml
@@ -451,7 +467,7 @@ auth:
451467

452468
- If `auth.type` is `custom`, you **must** specify `auth.impl` with the full class path to your custom AuthManager.
453469
- If `auth.type` is not `custom`, specifying `auth.impl` is not allowed.
454-
- The configuration block under each type (e.g., `basic`, `custom`) is passed as keyword arguments to the corresponding AuthManager.
470+
- The configuration block under each type (e.g., `basic`, `oauth2`, `custom`) is passed as keyword arguments to the corresponding AuthManager.
455471

456472
<!-- markdown-link-check-enable-->
457473

pyiceberg/catalog/rest/auth.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@
1818
import base64
1919
import importlib
2020
import logging
21+
import threading
22+
import time
2123
from abc import ABC, abstractmethod
24+
from functools import cached_property
2225
from typing import Any, Dict, List, Optional, Type
2326

27+
import requests
2428
from requests import HTTPError, PreparedRequest, Session
2529
from requests.auth import AuthBase
2630

@@ -121,6 +125,98 @@ def auth_header(self) -> str:
121125
return f"Bearer {self._token}"
122126

123127

128+
class OAuth2TokenProvider:
    """Thread-safe OAuth2 token provider with token refresh support.

    Requests an access token from ``token_url`` using the
    ``client_credentials`` grant, sending the client id and secret in a
    ``Basic`` ``Authorization`` header. The token is cached and re-requested
    once ``refresh_margin`` seconds or less remain before it expires.

    Args:
        client_id: OAuth2 client identifier.
        client_secret: OAuth2 client secret.
        token_url: URL of the token endpoint.
        scope: Optional scope sent with the token request.
        refresh_margin: Seconds before expiry at which the token is refreshed.
        expires_in: Fallback token lifetime in seconds, used when the token
            response carries no ``expires_in`` value.
        timeout: Seconds to wait for the token endpoint before giving up;
            ``None`` disables the timeout.
    """

    client_id: str
    client_secret: str
    token_url: str
    scope: Optional[str]
    refresh_margin: int
    expires_in: Optional[int]
    timeout: Optional[float]

    _token: Optional[str]
    # Deadline on the time.monotonic() clock, hence a float.
    _expires_at: float
    _lock: threading.Lock

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: Optional[str] = None,
        refresh_margin: int = 60,
        expires_in: Optional[int] = None,
        *,
        timeout: Optional[float] = 30.0,
    ):
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url
        self.scope = scope
        self.refresh_margin = refresh_margin
        self.expires_in = expires_in
        self.timeout = timeout

        self._token = None
        self._expires_at = 0.0
        self._lock = threading.Lock()

    @cached_property
    def _client_secret_header(self) -> str:
        """Return the ``Basic`` Authorization header value for the client credentials."""
        creds = f"{self.client_id}:{self.client_secret}"
        b64_creds = base64.b64encode(creds.encode("utf-8")).decode("utf-8")
        return f"Basic {b64_creds}"

    def _apply_token_response(self, result: Dict[str, Any]) -> None:
        """Store the token and its refresh deadline from a decoded token response.

        Raises:
            KeyError: If the response has no ``access_token`` field.
            ValueError: If neither the response nor the client configuration
                supplies a token lifetime.
        """
        token = result["access_token"]

        # An explicit null from the server must fall back to the client value too.
        expires_in = result.get("expires_in")
        if expires_in is None:
            expires_in = self.expires_in
        if expires_in is None:
            raise ValueError(
                "The expiration time of the Token must be provided by the Server in the Access Token Response in `expires_in` field, or by the PyIceberg Client."
            )

        # Update the cached state only after the response has been validated.
        self._token = token
        self._expires_at = time.monotonic() + expires_in - self.refresh_margin

    def _refresh_token(self) -> None:
        """Request a new access token from the token endpoint and cache it.

        Raises:
            requests.HTTPError: If the token endpoint answers with an error status.
        """
        data = {"grant_type": "client_credentials"}
        if self.scope:
            data["scope"] = self.scope

        # This runs while get_token() holds the lock, so an unbounded request
        # would stall every thread waiting for a token.
        response = requests.post(
            self.token_url,
            data=data,
            headers={"Authorization": self._client_secret_header},
            timeout=self.timeout,
        )
        response.raise_for_status()
        self._apply_token_response(response.json())

    def get_token(self) -> str:
        """Return a cached access token, refreshing it first if missing or due.

        Raises:
            ValueError: If no token is available after a refresh.
        """
        with self._lock:
            if not self._token or time.monotonic() >= self._expires_at:
                self._refresh_token()
            if self._token is None:
                raise ValueError("Authorization token is None after refresh")
            return self._token
193+
194+
195+
class OAuth2AuthManager(AuthManager):
    """AuthManager that authenticates with OAuth2 bearer tokens (IETF RFC6749).

    Token retrieval is delegated to an ``OAuth2TokenProvider`` built from the
    constructor arguments.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str,
        scope: Optional[str] = None,
        refresh_margin: int = 60,
        expires_in: Optional[int] = None,
    ):
        # Forward every setting unchanged to the token provider.
        self.token_provider = OAuth2TokenProvider(
            client_id=client_id,
            client_secret=client_secret,
            token_url=token_url,
            scope=scope,
            refresh_margin=refresh_margin,
            expires_in=expires_in,
        )

    def auth_header(self) -> str:
        """Return the ``Bearer`` Authorization header value for the provider's token."""
        access_token = self.token_provider.get_token()
        return f"Bearer {access_token}"
218+
219+
124220
class GoogleAuthManager(AuthManager):
125221
"""An auth manager that is responsible for handling Google credentials."""
126222

@@ -228,4 +324,5 @@ def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager:
228324
AuthManagerFactory.register("noop", NoopAuthManager)
229325
AuthManagerFactory.register("basic", BasicAuthManager)
230326
AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager)
327+
AuthManagerFactory.register("oauth2", OAuth2AuthManager)
231328
AuthManagerFactory.register("google", GoogleAuthManager)

tests/catalog/test_rest.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest import mock
2222

2323
import pytest
24+
from requests.exceptions import HTTPError
2425
from requests_mock import Mocker
2526

2627
import pyiceberg
@@ -1646,6 +1647,63 @@ def test_rest_catalog_with_unsupported_auth_type() -> None:
16461647
assert "Could not load AuthManager class for 'unsupported'" in str(e.value)
16471648

16481649

1650+
def test_rest_catalog_with_oauth2_auth_type(requests_mock: Mocker) -> None:
    """A catalog configured with `oauth2` auth can be built against mocked endpoints."""
    token_endpoint = f"{TEST_URI}oauth2/token"
    token_payload = {
        "access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": "IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk",
        "scope": "read",
    }
    requests_mock.post(token_endpoint, json=token_payload, status_code=200)
    requests_mock.get(
        f"{TEST_URI}v1/config",
        json={"defaults": {}, "overrides": {}},
        status_code=200,
    )

    # Given
    oauth2_settings = {
        "client_id": "some_client_id",
        "client_secret": "some_client_secret",
        "token_url": f"{TEST_URI}oauth2/token",
        "scope": "read",
    }
    catalog_properties = {
        "uri": TEST_URI,
        "auth": {"type": "oauth2", "oauth2": oauth2_settings},
    }

    catalog = RestCatalog("rest", **catalog_properties)  # type: ignore
    assert catalog.uri == TEST_URI
1682+
1683+
1684+
def test_rest_catalog_oauth2_non_200_token_response(requests_mock: Mocker) -> None:
    """Building the catalog raises HTTPError when the mocked token endpoint returns 401."""
    token_endpoint = f"{TEST_URI}oauth2/token"
    requests_mock.post(token_endpoint, json={"error": "invalid_client"}, status_code=401)

    oauth2_settings = {
        "client_id": "bad_client_id",
        "client_secret": "bad_client_secret",
        "token_url": f"{TEST_URI}oauth2/token",
        "scope": "read",
    }
    catalog_properties = {
        "uri": TEST_URI,
        "auth": {"type": "oauth2", "oauth2": oauth2_settings},
    }

    with pytest.raises(HTTPError):
        RestCatalog("rest", **catalog_properties)  # type: ignore
1705+
1706+
16491707
EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}
16501708

16511709

0 commit comments

Comments
 (0)