|
| 1 | +"""Authentication providers to enrich API requests with authorization headers.""" |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +from abc import ABC, abstractmethod |
| 5 | +from datetime import datetime, timedelta |
| 6 | +from typing import Callable, Dict, Optional, Tuple |
| 7 | + |
| 8 | +AuthHeaders = Dict[str, str] |
| 9 | +TokenWithExpiry = Tuple[str, Optional[datetime]] |
| 10 | + |
| 11 | + |
| 12 | +class AuthProvider(ABC): |
| 13 | + """Represents a strategy capable of injecting authentication information.""" |
| 14 | + |
| 15 | + @abstractmethod |
| 16 | + def apply(self, headers: Optional[AuthHeaders] = None) -> AuthHeaders: |
| 17 | + """Return request headers containing the required authentication details.""" |
| 18 | + |
| 19 | + |
| 20 | +class ApiKeyAuth(AuthProvider): |
| 21 | + """Static API key authentication added to a configurable header.""" |
| 22 | + |
| 23 | + def __init__(self, header_name: str, api_key: str): |
| 24 | + """Store the header name and API key used for authenticated requests.""" |
| 25 | + self.header_name = header_name |
| 26 | + self.api_key = api_key |
| 27 | + |
| 28 | + def apply(self, headers: Optional[AuthHeaders] = None) -> AuthHeaders: |
| 29 | + """Return headers with the API key injected into the configured header.""" |
| 30 | + composed_headers = dict(headers or {}) |
| 31 | + composed_headers[self.header_name] = self.api_key |
| 32 | + return composed_headers |
| 33 | + |
| 34 | + |
| 35 | +class BearerTokenAuth(AuthProvider): |
| 36 | + """Bearer token authentication supporting automatic token refresh.""" |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + token_supplier: Callable[[], TokenWithExpiry], |
| 41 | + *, |
| 42 | + header_name: str = "Authorization", |
| 43 | + scheme: str = "Bearer", |
| 44 | + refresh_margin: float = 0, |
| 45 | + clock: Callable[[], datetime] = datetime.utcnow, |
| 46 | + ): |
| 47 | + """Configure the bearer token strategy and how tokens are supplied.""" |
| 48 | + self._token_supplier = token_supplier |
| 49 | + self.header_name = header_name |
| 50 | + self.scheme = scheme |
| 51 | + self._token: Optional[str] = None |
| 52 | + self._expires_at: Optional[datetime] = None |
| 53 | + self._refresh_margin = timedelta(seconds=refresh_margin) |
| 54 | + self._clock = clock |
| 55 | + |
| 56 | + def _ensure_token(self) -> None: |
| 57 | + """Fetch or refresh the bearer token when it is missing or expired.""" |
| 58 | + if self._token is None or self._should_refresh(): |
| 59 | + self._token, self._expires_at = self._token_supplier() |
| 60 | + |
| 61 | + def _should_refresh(self) -> bool: |
| 62 | + """Determine whether a new token is required based on the expiry time.""" |
| 63 | + if self._expires_at is None: |
| 64 | + return False |
| 65 | + return self._expires_at <= self._clock() + self._refresh_margin |
| 66 | + |
| 67 | + def apply(self, headers: Optional[AuthHeaders] = None) -> AuthHeaders: |
| 68 | + """Return headers containing a valid bearer token with optional refresh.""" |
| 69 | + self._ensure_token() |
| 70 | + composed_headers = dict(headers or {}) |
| 71 | + composed_headers[self.header_name] = f"{self.scheme} {self._token}" |
| 72 | + return composed_headers |
| 73 | + |
| 74 | + |
| 75 | +class OAuth2ClientCredentials(AuthProvider): |
| 76 | + """OAuth2 Client Credentials authentication with token renewal support.""" |
| 77 | + |
| 78 | + def __init__( |
| 79 | + self, |
| 80 | + client_id: str, |
| 81 | + client_secret: str, |
| 82 | + token_fetcher: Callable[[str, str, Optional[str]], TokenWithExpiry], |
| 83 | + *, |
| 84 | + scope: Optional[str] = None, |
| 85 | + header_name: str = "Authorization", |
| 86 | + refresh_margin: float = 30, |
| 87 | + clock: Callable[[], datetime] = datetime.utcnow, |
| 88 | + ): |
| 89 | + """Store client credentials and the callable responsible for new tokens.""" |
| 90 | + self.client_id = client_id |
| 91 | + self.client_secret = client_secret |
| 92 | + self.scope = scope |
| 93 | + self._token_fetcher = token_fetcher |
| 94 | + self.header_name = header_name |
| 95 | + self._refresh_margin = timedelta(seconds=refresh_margin) |
| 96 | + self._clock = clock |
| 97 | + self._token: Optional[str] = None |
| 98 | + self._expires_at: Optional[datetime] = None |
| 99 | + |
| 100 | + def _ensure_token(self) -> None: |
| 101 | + """Obtain a valid OAuth2 access token using the configured fetcher.""" |
| 102 | + if self._token is None or self._should_refresh(): |
| 103 | + self._token, self._expires_at = self._token_fetcher( |
| 104 | + self.client_id, self.client_secret, self.scope |
| 105 | + ) |
| 106 | + |
| 107 | + def _should_refresh(self) -> bool: |
| 108 | + """Determine whether the current OAuth2 token needs to be refreshed.""" |
| 109 | + if self._expires_at is None: |
| 110 | + return False |
| 111 | + return self._expires_at <= self._clock() + self._refresh_margin |
| 112 | + |
| 113 | + def apply(self, headers: Optional[AuthHeaders] = None) -> AuthHeaders: |
| 114 | + """Return headers augmented with a fresh OAuth2 bearer token.""" |
| 115 | + self._ensure_token() |
| 116 | + composed_headers = dict(headers or {}) |
| 117 | + composed_headers[self.header_name] = f"Bearer {self._token}" |
| 118 | + return composed_headers |
| 119 | + |
0 commit comments