Skip to content

Commit 9406632

Browse files
authored
Merge pull request #128 from henriaidasso/auth-handler
2 parents aa285f6 + 081d85c commit 9406632

File tree

15 files changed

+989
-92
lines changed

15 files changed

+989
-92
lines changed

STACpopulator/auth/__init__.py

Whitespace-only changes.

STACpopulator/auth/handlers.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
import inspect
5+
import logging
6+
from http import cookiejar
7+
from typing import Any, Optional, Type
8+
9+
from requests import PreparedRequest, Response
10+
from requests.auth import AuthBase, HTTPBasicAuth, HTTPDigestAuth, HTTPProxyAuth
11+
from requests.structures import CaseInsensitiveDict
12+
13+
from STACpopulator.auth.utils import fully_qualified_name, make_request
14+
from STACpopulator.exceptions import AuthenticationError
15+
from STACpopulator.request.typedefs import (
16+
APP_JSON,
17+
AnyHeadersContainer,
18+
AnyRequestType,
19+
CookiesType,
20+
RequestMethod,
21+
)
22+
23+
LOGGER = logging.getLogger(__name__)
24+
25+
26+
class AuthHandler(AuthBase):
27+
"""Authentication handler class."""
28+
29+
url: Optional[str]
30+
method: RequestMethod
31+
headers: AnyHeadersContainer
32+
identity: Optional[str]
33+
password: Optional[str]
34+
35+
def __init__(
36+
self,
37+
identity: Optional[str] = None,
38+
password: Optional[str] = None,
39+
url: Optional[str] = None,
40+
method: RequestMethod = "GET",
41+
headers: Optional[AnyHeadersContainer] = None,
42+
) -> None:
43+
self.identity = identity
44+
self.password = password
45+
self.url = url
46+
self.method = method if method is not None else "GET"
47+
self.headers = headers if headers is not None else {}
48+
49+
@abc.abstractmethod
50+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
51+
"""Call method to perform inline authentication retrieval prior to sending the request."""
52+
raise NotImplementedError
53+
54+
@staticmethod
55+
def from_data(
56+
auth_handler: Optional[Type[AuthHandler]] = None,
57+
auth_identity: Optional[str] = None,
58+
auth_url: Optional[str] = None,
59+
auth_method: Optional[str] = None,
60+
auth_headers: Optional[AnyHeadersContainer] = None,
61+
auth_token: Optional[str] = None,
62+
) -> Optional[AuthHandler]:
63+
"""Parse arguments that define an authentication handler.
64+
65+
:param auth_handler: The authentication handler class to instantiate.
66+
:param auth_identity: Identity string, optionally containing password as "user:pass".
67+
:param auth_url: URL for authentication.
68+
:param auth_method: Authentication method (HTTP verb).
69+
:param auth_headers: Additional headers for authentication.
70+
:param auth_token: Authentication token.
71+
72+
:return: An instantiated `AuthHandler`, or None if `auth_handler` is invalid.
73+
"""
74+
if not (auth_handler and issubclass(auth_handler, (AuthHandler, AuthBase))):
75+
return None
76+
77+
auth_password = None
78+
if auth_identity and ":" in auth_identity:
79+
auth_identity, auth_password = auth_identity.split(":", 1)
80+
81+
auth_headers = auth_headers or {}
82+
83+
auth_handler_sign = inspect.signature(auth_handler)
84+
auth_opts = [
85+
("username", auth_identity),
86+
("identity", auth_identity),
87+
("password", auth_password),
88+
("url", auth_url),
89+
("method", auth_method),
90+
("headers", CaseInsensitiveDict(auth_headers)),
91+
("token", auth_token),
92+
]
93+
94+
if not auth_handler_sign.parameters:
95+
auth_handler_obj = auth_handler()
96+
for auth_param, auth_option in auth_opts:
97+
if auth_option and hasattr(auth_handler_obj, auth_param):
98+
setattr(auth_handler_obj, auth_param, auth_option)
99+
else:
100+
auth_params = list(auth_handler_sign.parameters)
101+
auth_kwargs = {opt: val for opt, val in auth_opts if opt in auth_params}
102+
103+
# allow partial match of required parameters by name to support custom implementations
104+
# (e.g.: 'MagpieAuth' using 'magpie_url' instead of plain 'url')
105+
for param_name, param in auth_handler_sign.parameters.items():
106+
if param.kind not in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
107+
continue
108+
if param_name not in auth_kwargs:
109+
for opt, val in auth_opts:
110+
if param_name.endswith(opt):
111+
LOGGER.debug("Using authentication partial match: [%s] -> [%s]", opt, param_name)
112+
auth_kwargs[param_name] = val
113+
break
114+
LOGGER.debug("Using authentication parameters: %s", auth_kwargs)
115+
auth_handler_obj = auth_handler(**auth_kwargs)
116+
LOGGER.info(
117+
"Will use specified Authentication Handler [%s] with provided options.",
118+
fully_qualified_name(auth_handler),
119+
)
120+
return auth_handler_obj
121+
122+
123+
class BasicAuthHandler(AuthHandler, HTTPBasicAuth):
124+
"""Basic authentication handler class.
125+
126+
Adds the `Authorization` header formed from basic authentication encoding of username and password to the request.
127+
128+
Authentication URL and method are not needed for this handler.
129+
"""
130+
131+
def __init__(self, username: str, password: str, **kwargs) -> None:
132+
AuthHandler.__init__(self, identity=username, password=password, **kwargs)
133+
HTTPBasicAuth.__init__(self, username=username, password=password)
134+
135+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
136+
"""Call method to perform authentication prior to sending the request."""
137+
return HTTPBasicAuth.__call__(self, r)
138+
139+
140+
class DigestAuthHandler(AuthHandler, HTTPDigestAuth):
141+
"""Digest authentication handler class."""
142+
143+
def __init__(self, username: str, password: str, **kwargs) -> None:
144+
AuthHandler.__init__(self, identity=username, password=password, **kwargs)
145+
HTTPDigestAuth.__init__(self, username=username, password=password)
146+
147+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
148+
"""Call method to perform authentication prior to sending the request."""
149+
return HTTPDigestAuth.__call__(self, r)
150+
151+
152+
class ProxyAuthHandler(AuthHandler, HTTPProxyAuth):
153+
"""Proxy authentication handler class."""
154+
155+
def __init__(self, username: str, password: str, **kwargs) -> None:
156+
AuthHandler.__init__(self, identity=username, password=password, **kwargs)
157+
HTTPProxyAuth.__init__(self, username=username, password=password)
158+
159+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
160+
"""Call method to perform authentication prior to sending the request."""
161+
return HTTPProxyAuth.__call__(self, r)
162+
163+
164+
class CookieJarAuthHandler(AuthHandler):
165+
"""Cookie jar authentication handler class."""
166+
167+
def __init__(self, identity: str, **kwargs) -> None:
168+
AuthHandler.__init__(self, identity=identity, **kwargs)
169+
self.cookiefile = identity
170+
self._cookiejar = None
171+
172+
def __call__(self, r: PreparedRequest) -> PreparedRequest:
173+
"""Call method loading cookie jar prior to sending the request."""
174+
# Lazy-load cookie jar
175+
if self._cookiejar is None:
176+
jar = cookiejar.MozillaCookieJar(self.cookie_file)
177+
jar.load(ignore_discard=True, ignore_expires=True)
178+
self._cookiejar = jar
179+
180+
r._cookies = self._cookiejar
181+
return r
182+
183+
184+
class RequestAuthHandler(AuthHandler):
185+
"""Base class to send a request in order to retrieve an authorization token."""
186+
187+
def __init__(
188+
self,
189+
identity: Optional[str] = None,
190+
password: Optional[str] = None,
191+
url: Optional[str] = None,
192+
method: RequestMethod = "GET",
193+
headers: Optional[AnyHeadersContainer] = None,
194+
token: Optional[str] = None,
195+
) -> None:
196+
AuthHandler.__init__(
197+
self,
198+
identity=identity,
199+
password=password,
200+
url=url,
201+
method=method,
202+
headers=headers,
203+
)
204+
self.token = token
205+
self._common_token_names = ["auth", "access_token", "token"]
206+
207+
if not self.token and not self.url:
208+
raise AuthenticationError("Either the token or the URL to retrieve it must be provided to the handler.")
209+
210+
@property
211+
def auth_token_name(self) -> Optional[str]:
212+
"""Override token name to retrieve in response authentication handler implementation.
213+
214+
Defaults to `None` and auth handler then looks amongst common names: [`auth`, `access_token`, `token`]
215+
"""
216+
return None
217+
218+
@abc.abstractmethod
219+
def auth_header(self, token: str) -> AnyHeadersContainer:
220+
"""Get the header definition with the provided authorization token."""
221+
raise NotImplementedError
222+
223+
@staticmethod
224+
@abc.abstractmethod
225+
def parse_token(token: Any) -> str:
226+
"""Parse token to a format that can be included in a request header."""
227+
raise NotImplementedError
228+
229+
def authenticate(self) -> Optional[str]:
230+
"""Launch an authentication request to retrieve the authorization token."""
231+
auth_headers = {"Accept": APP_JSON}
232+
auth_headers.update(self.headers)
233+
resp = make_request(self.method, self.url, headers=auth_headers)
234+
if not resp.ok:
235+
return None
236+
return self.get_token_from_response(resp)
237+
238+
def get_token_from_response(self, response: Response) -> Optional[str]:
239+
"""Extract the authorization token from a valid authentication response."""
240+
content_type = response.headers.get("Content-Type")
241+
if not content_type == APP_JSON:
242+
return None
243+
244+
body = response.json()
245+
if self.auth_token_name:
246+
auth_token = body.get(self.auth_token_name)
247+
else:
248+
auth_token = next(
249+
(body[name] for name in self._common_token_names if name in body),
250+
None,
251+
)
252+
return auth_token
253+
254+
def __call__(self, request: AnyRequestType) -> AnyRequestType:
255+
"""Call method handling authentication and request forward."""
256+
auth_token = self.authenticate() if self.token is None and self.url else self.token
257+
if not auth_token:
258+
LOGGER.warning(
259+
"Expected authorization token could not be retrieved from URL: [%s] in [%s]",
260+
self.url,
261+
fully_qualified_name(self),
262+
)
263+
else:
264+
auth_token = self.parse_token(auth_token)
265+
auth_header = self.auth_header(auth_token)
266+
request.headers.update(auth_header)
267+
return request
268+
269+
270+
class BearerAuthHandler(RequestAuthHandler):
271+
"""Bearer authentication handler class.
272+
273+
Adds the ``Authorization`` header formed of the authentication bearer token from the underlying request.
274+
"""
275+
276+
@staticmethod
277+
def parse_token(token: str) -> str:
278+
"""Parse token to a form that can be included in a request header."""
279+
return token
280+
281+
def auth_header(self, token: str) -> AnyHeadersContainer:
282+
"""Header definition for bearer token-based authentication."""
283+
return {"Authorization": f"Bearer {token}"}
284+
285+
286+
class CookieAuthHandler(RequestAuthHandler):
287+
"""Cookie-based authentication handler class.
288+
289+
Adds the ``Cookie`` header formed from the authentication bearer token from the underlying request.
290+
"""
291+
292+
def __init__(
293+
self,
294+
identity: Optional[str] = None,
295+
password: Optional[str] = None,
296+
url: Optional[str] = None,
297+
method: RequestMethod = "GET",
298+
headers: Optional[AnyHeadersContainer] = None,
299+
token: Optional[str | CookiesType] = None,
300+
) -> None:
301+
super().__init__(
302+
identity=identity,
303+
password=password,
304+
url=url,
305+
method=method,
306+
headers=headers,
307+
token=token,
308+
)
309+
310+
@staticmethod
311+
def parse_token(token: str | CookiesType) -> str:
312+
"""Parse token to a form that can be included in a request `Cookie` header.
313+
314+
Returns the token string as is if it's a string. Otherwise, if the token is a mapping where keys are cookie
315+
names and values are cookie values, converts the cookie to a `key=val;...` string that can be accepted as the
316+
value of the "Cookie" header.
317+
"""
318+
if isinstance(token, str):
319+
return token
320+
cookie_dict = CaseInsensitiveDict(token)
321+
return "; ".join(f"{key}={val}" for key, val in cookie_dict.items())
322+
323+
def auth_header(self, token: str) -> AnyHeadersContainer:
324+
"""Header definition for cookie-based authentication."""
325+
return {"Cookie": token}

STACpopulator/auth/utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import inspect
2+
import logging
3+
from typing import Any, Type, Union
4+
5+
import requests
6+
from requests import Response
7+
8+
from STACpopulator.request.typedefs import RequestMethod
9+
10+
LOGGER = logging.getLogger(__name__)
11+
12+
13+
def make_request(
14+
method: RequestMethod,
15+
url: str,
16+
ssl_verify: bool = True,
17+
**request_kwargs: Any,
18+
) -> Response:
19+
"""Make an HTTP request with additional request options.
20+
21+
Parameters
22+
----------
23+
method: AnyRequestMethod
24+
The HTTP method to use (e.g., 'GET', 'POST').
25+
url: str
26+
The target URL for the request.
27+
ssl_verify: bool, optional
28+
Whether to verify SSL certificates (default is True). Overrides the `verify` parameter in `requests`.
29+
request_kwargs: dict, optional
30+
Additional keyword arguments to pass to the underlying request, such as headers, data, params, json, etc.
31+
32+
Returns
33+
-------
34+
Response
35+
The response object returned by the request, typically containing status, headers, and content.
36+
"""
37+
request_kwargs.setdefault("timeout", 5)
38+
request_kwargs.setdefault("verify", ssl_verify)
39+
# remove leftover options unknown to requests method in case of multiple entries
40+
known_req_opts = set(inspect.signature(requests.Session.request).parameters)
41+
known_req_opts -= {
42+
"url",
43+
"method",
44+
} # add as unknown to always remove them since they are passed by arguments
45+
for req_opt in set(request_kwargs) - known_req_opts:
46+
request_kwargs.pop(req_opt)
47+
res = requests.request(method, url, **request_kwargs)
48+
return res
49+
50+
51+
def fully_qualified_name(obj: Union[Any, Type[Any]]) -> str:
52+
"""Get the full path definition of the object to allow finding and importing it.
53+
54+
For classes, functions, and exceptions, the returned format is:
55+
56+
.. code-block:: python
57+
module.name
58+
59+
The ``module`` is omitted if it is a builtin object or type.
60+
61+
For methods, the owning class is also included, resulting in:
62+
63+
.. code-block:: python
64+
module.class.name
65+
"""
66+
if inspect.ismethod(obj):
67+
return ".".join([obj.__module__, obj.__qualname__])
68+
cls = obj if inspect.isclass(obj) or inspect.isfunction(obj) else type(obj)
69+
if "builtins" in getattr(cls, "__module__", "builtins"): # sometimes '_sitebuiltins'
70+
return cls.__name__
71+
return ".".join([cls.__module__, cls.__name__])

0 commit comments

Comments
 (0)