Skip to content

Commit 0e3e80d

Browse files
authored
Introduce AuthManager (#1908)
<!-- Thanks for opening a pull request! --> <!-- In the case this PR will resolve an issue, please replace ${GITHUB_ISSUE_ID} below with the actual Github issue id. --> <!-- Closes #${GITHUB_ISSUE_ID} --> # Rationale for this change #1906 # Are these changes tested? Yes, unit tested # Are there any user-facing changes? Not yet <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent a1287d4 commit 0e3e80d

File tree

3 files changed

+147
-0
lines changed

3 files changed

+147
-0
lines changed
File renamed without changes.

pyiceberg/catalog/rest/auth.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import base64
19+
from abc import ABC, abstractmethod
20+
from typing import Optional
21+
22+
from requests import PreparedRequest
23+
from requests.auth import AuthBase
24+
25+
26+
class AuthManager(ABC):
27+
"""
28+
Abstract base class for Authentication Managers used to supply authorization headers to HTTP clients (e.g. requests.Session).
29+
30+
Subclasses must implement the `auth_header` method to return an Authorization header value.
31+
"""
32+
33+
@abstractmethod
34+
def auth_header(self) -> Optional[str]:
35+
"""Return the Authorization header value, or None if not applicable."""
36+
37+
38+
class NoopAuthManager(AuthManager):
39+
def auth_header(self) -> Optional[str]:
40+
return None
41+
42+
43+
class BasicAuthManager(AuthManager):
44+
def __init__(self, username: str, password: str):
45+
credentials = f"{username}:{password}"
46+
self._token = base64.b64encode(credentials.encode()).decode()
47+
48+
def auth_header(self) -> str:
49+
return f"Basic {self._token}"
50+
51+
52+
class AuthManagerAdapter(AuthBase):
53+
"""A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request.
54+
55+
This adapter is useful when working with `requests.Session.auth`
56+
and allows reuse of authentication strategies defined by `AuthManager`.
57+
This AuthManagerAdapter is only intended to be used against the REST Catalog
58+
Server that expects the Authorization Header.
59+
"""
60+
61+
def __init__(self, auth_manager: AuthManager):
62+
"""
63+
Initialize AuthManagerAdapter.
64+
65+
Args:
66+
auth_manager (AuthManager): An instance of an AuthManager subclass.
67+
"""
68+
self.auth_manager = auth_manager
69+
70+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
71+
"""
72+
Modify the outgoing request to include the Authorization header.
73+
74+
Args:
75+
request (requests.PreparedRequest): The HTTP request being prepared.
76+
77+
Returns:
78+
requests.PreparedRequest: The modified request with Authorization header.
79+
"""
80+
if auth_header := self.auth_manager.auth_header():
81+
request.headers["Authorization"] = auth_header
82+
return request

tests/catalog/test_rest_auth.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import base64
19+
20+
import pytest
21+
import requests
22+
from requests_mock import Mocker
23+
24+
from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, NoopAuthManager
25+
26+
TEST_URI = "https://iceberg-test-catalog/"
27+
28+
29+
@pytest.fixture
30+
def rest_mock(requests_mock: Mocker) -> Mocker:
31+
requests_mock.get(
32+
TEST_URI,
33+
json={},
34+
status_code=200,
35+
)
36+
return requests_mock
37+
38+
39+
def test_noop_auth_header(rest_mock: Mocker) -> None:
40+
auth_manager = NoopAuthManager()
41+
session = requests.Session()
42+
session.auth = AuthManagerAdapter(auth_manager)
43+
44+
session.get(TEST_URI)
45+
history = rest_mock.request_history
46+
assert len(history) == 1
47+
actual_headers = history[0].headers
48+
assert "Authorization" not in actual_headers
49+
50+
51+
def test_basic_auth_header(rest_mock: Mocker) -> None:
52+
username = "testuser"
53+
password = "testpassword"
54+
expected_token = base64.b64encode(f"{username}:{password}".encode()).decode()
55+
expected_header = f"Basic {expected_token}"
56+
57+
auth_manager = BasicAuthManager(username=username, password=password)
58+
session = requests.Session()
59+
session.auth = AuthManagerAdapter(auth_manager)
60+
61+
session.get(TEST_URI)
62+
history = rest_mock.request_history
63+
assert len(history) == 1
64+
actual_headers = history[0].headers
65+
assert actual_headers["Authorization"] == expected_header

0 commit comments

Comments
 (0)