Skip to content

Commit 720244a

Browse files
committed
use cursor to make a new version of the classes
1 parent 533b70a commit 720244a

File tree

2 files changed

+332
-1
lines changed

2 files changed

+332
-1
lines changed

airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
44

5-
from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator
5+
from .oauth_v2 import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator
66
from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator
77

88
__all__ = [
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
#
2+
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
from datetime import timedelta
6+
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
7+
import requests
8+
from requests.auth import AuthBase
9+
import dpath
10+
import json
11+
from json import JSONDecodeError
12+
from abc import abstractmethod
13+
14+
from airbyte_cdk.config_observation import (
15+
create_connector_config_control_message,
16+
emit_configuration_as_airbyte_control_message,
17+
)
18+
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
19+
from airbyte_cdk.utils.datetime_helpers import (
20+
AirbyteDateTime,
21+
ab_datetime_now,
22+
ab_datetime_parse,
23+
)
24+
from airbyte_cdk.models import FailureType, Level
25+
from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException
26+
from airbyte_cdk.utils import AirbyteTracedException
27+
from airbyte_cdk.sources.http_logger import format_http_message
28+
29+
class AbstractOauth2Authenticator(AuthBase):
30+
"""Base class for OAuth2 authentication with shared token refresh logic"""
31+
32+
def __init__(
33+
self,
34+
token_refresh_endpoint: str,
35+
client_id: str,
36+
client_secret: str,
37+
refresh_token: str,
38+
client_id_name: str = "client_id",
39+
client_secret_name: str = "client_secret",
40+
refresh_token_name: str = "refresh_token",
41+
scopes: Optional[List[str]] = None,
42+
access_token_name: str = "access_token",
43+
expires_in_name: str = "expires_in",
44+
refresh_request_body: Optional[Mapping[str, Any]] = None,
45+
refresh_request_headers: Optional[Mapping[str, Any]] = None,
46+
grant_type_name: str = "grant_type",
47+
grant_type: str = "refresh_token",
48+
token_expiry_date: Optional[AirbyteDateTime] = None,
49+
token_expiry_date_format: Optional[str] = None,
50+
token_expiry_is_time_of_expiration: bool = False,
51+
refresh_token_error_status_codes: Tuple[int, ...] = (),
52+
refresh_token_error_key: str = "",
53+
refresh_token_error_values: Tuple[str, ...] = (),
54+
):
55+
self._token_refresh_endpoint = token_refresh_endpoint
56+
self._client_id = client_id
57+
self._client_secret = client_secret
58+
self._refresh_token = refresh_token
59+
self._client_id_name = client_id_name
60+
self._client_secret_name = client_secret_name
61+
self._refresh_token_name = refresh_token_name
62+
self._scopes = scopes
63+
self._access_token_name = access_token_name
64+
self._expires_in_name = expires_in_name
65+
self._refresh_request_body = refresh_request_body
66+
self._refresh_request_headers = refresh_request_headers
67+
self._grant_type_name = grant_type_name
68+
self._grant_type = grant_type
69+
self._token_expiry_date = token_expiry_date
70+
self._token_expiry_date_format = token_expiry_date_format
71+
self._token_expiry_is_time_of_expiration = token_expiry_is_time_of_expiration
72+
self._refresh_token_error_status_codes = refresh_token_error_status_codes
73+
self._refresh_token_error_key = refresh_token_error_key
74+
self._refresh_token_error_values = refresh_token_error_values
75+
76+
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
77+
request.headers.update(self.get_auth_header())
78+
return request
79+
80+
def get_auth_header(self) -> Mapping[str, Any]:
81+
return {"Authorization": f"Bearer {self.get_access_token()}"}
82+
83+
def get_access_token(self) -> str:
84+
"""Get the current access token, refreshing if expired"""
85+
if self.token_has_expired():
86+
self.refresh_access_token()
87+
return self.access_token
88+
89+
def token_has_expired(self) -> bool:
90+
"""Check if the current token has expired"""
91+
return not self.get_token_expiry_date() or ab_datetime_now() > self.get_token_expiry_date()
92+
93+
def _make_refresh_request(self) -> Mapping[str, Any]:
94+
"""
95+
Make the HTTP request to refresh OAuth tokens.
96+
97+
Returns:
98+
Mapping[str, Any]: The JSON response from the token refresh endpoint.
99+
100+
Raises:
101+
DefaultBackoffException: If the response status code is 429 (Too Many Requests)
102+
or any 5xx server error.
103+
AirbyteTracedException: If the refresh token is invalid or expired, prompting
104+
re-authentication.
105+
"""
106+
try:
107+
response = requests.post(
108+
url=self._token_refresh_endpoint,
109+
data=self._get_refresh_request_body(),
110+
headers=self._refresh_request_headers,
111+
)
112+
self._log_response(response)
113+
response.raise_for_status()
114+
return response.json()
115+
except requests.exceptions.RequestException as e:
116+
if e.response is not None:
117+
if e.response.status_code == 429 or e.response.status_code >= 500:
118+
raise DefaultBackoffException(request=e.response.request, response=e.response)
119+
if e.response.status_code in self._refresh_token_error_status_codes:
120+
try:
121+
error_value = e.response.json().get(self._refresh_token_error_key)
122+
if error_value in self._refresh_token_error_values:
123+
message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
124+
raise AirbyteTracedException(
125+
message=message,
126+
internal_message=message,
127+
failure_type=FailureType.config_error
128+
)
129+
except JSONDecodeError:
130+
pass
131+
raise
132+
133+
def _get_refresh_request_body(self) -> Mapping[str, Any]:
134+
"""Get the request body for token refresh"""
135+
body = {
136+
self._grant_type_name: self._grant_type,
137+
self._client_id_name: self._client_id,
138+
self._client_secret_name: self._client_secret,
139+
self._refresh_token_name: self._refresh_token,
140+
}
141+
if self._scopes:
142+
body["scopes"] = self._scopes
143+
if self._refresh_request_body:
144+
# We defer to existing oauth constructs over custom configured fields
145+
for key, val in self._refresh_request_body.items():
146+
if key not in body:
147+
body[key] = val
148+
return body
149+
150+
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
151+
"""Parse token expiration field from response into an AirbyteDateTime object"""
152+
if self._token_expiry_is_time_of_expiration:
153+
if not self._token_expiry_date_format:
154+
raise ValueError("Token expiry date format required when using expiration time")
155+
return ab_datetime_parse(str(value))
156+
else:
157+
try:
158+
seconds = int(float(str(value)))
159+
return ab_datetime_now() + timedelta(seconds=seconds)
160+
except (ValueError, TypeError):
161+
raise ValueError(f"Invalid expires_in value: {value}")
162+
163+
@property
164+
def _message_repository(self) -> MessageRepository:
165+
"""Override in subclasses that need message logging"""
166+
return NoopMessageRepository()
167+
168+
def _log_response(self, response: requests.Response) -> None:
169+
"""Log the response if a message repository is configured"""
170+
if self._message_repository:
171+
self._message_repository.log_message(
172+
Level.DEBUG,
173+
lambda: format_http_message(
174+
response,
175+
"Refresh token",
176+
"Obtains access token",
177+
None,
178+
is_auxiliary=True,
179+
type="AUTH",
180+
),
181+
)
182+
183+
@property
184+
@abstractmethod
185+
def access_token(self) -> str:
186+
"""Get current access token - implemented by subclasses"""
187+
pass
188+
189+
@abstractmethod
190+
def refresh_access_token(self) -> None:
191+
"""Refresh the access token - implemented by subclasses"""
192+
pass
193+
194+
@abstractmethod
195+
def get_token_expiry_date(self) -> Optional[AirbyteDateTime]:
196+
"""Get token expiry date - implemented by subclasses"""
197+
pass
198+
199+
class Oauth2Authenticator(AbstractOauth2Authenticator):
200+
"""OAuth2 authenticator that stores tokens in memory"""
201+
202+
def __init__(self, *args, **kwargs):
203+
super().__init__(*args, **kwargs)
204+
self._access_token: Optional[str] = None
205+
self._token_expiry_date: Optional[AirbyteDateTime] = None
206+
207+
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
208+
"""Refresh access token and return token and expiry"""
209+
response = self._make_refresh_request()
210+
access_token = response[self._access_token_name]
211+
expires_in = response.get(self._expires_in_name)
212+
213+
self._access_token = access_token
214+
if expires_in:
215+
self._token_expiry_date = self._parse_token_expiration_date(expires_in)
216+
217+
return access_token, expires_in
218+
219+
def get_token_expiry_date(self) -> Optional[AirbyteDateTime]:
220+
return self._token_expiry_date
221+
222+
@property
223+
def access_token(self) -> str:
224+
if not self._access_token:
225+
raise ValueError("Access token not set")
226+
return self._access_token
227+
228+
class SingleUseRefreshTokenOauth2Authenticator(AbstractOauth2Authenticator):
229+
"""OAuth2 authenticator that stores tokens in config and emits updates"""
230+
231+
def __init__(
232+
self,
233+
connector_config: Mapping[str, Any],
234+
token_refresh_endpoint: str,
235+
client_id: Optional[str] = None,
236+
client_secret: Optional[str] = None,
237+
access_token_config_path: Sequence[str] = ("credentials", "access_token"),
238+
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
239+
token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"),
240+
message_repository: MessageRepository = NoopMessageRepository(),
241+
**kwargs
242+
):
243+
self._connector_config = connector_config
244+
self._access_token_config_path = access_token_config_path
245+
self._refresh_token_config_path = refresh_token_config_path
246+
self._token_expiry_date_config_path = token_expiry_date_config_path
247+
self.__message_repository = message_repository
248+
249+
# Get credentials from config if not provided
250+
if not client_id:
251+
client_id = self._get_config_value_by_path(("credentials", "client_id"))
252+
if not client_secret:
253+
client_secret = self._get_config_value_by_path(("credentials", "client_secret"))
254+
255+
super().__init__(
256+
token_refresh_endpoint=token_refresh_endpoint,
257+
client_id=client_id,
258+
client_secret=client_secret,
259+
refresh_token=self._get_config_value_by_path(refresh_token_config_path),
260+
**kwargs
261+
)
262+
263+
def refresh_access_token(self) -> Tuple[str, Union[str, int], str]:
264+
"""Refresh access token and update config"""
265+
response = self._make_refresh_request()
266+
267+
access_token = response[self._access_token_name]
268+
refresh_token = response.get(self._refresh_token_name)
269+
expires_in = response.get(self._expires_in_name)
270+
271+
self._update_config(access_token, refresh_token, expires_in)
272+
return access_token, expires_in, refresh_token
273+
274+
def _update_config(
275+
self,
276+
access_token: str,
277+
refresh_token: Optional[str],
278+
expires_in: Optional[Union[str, int]]
279+
) -> None:
280+
"""Update the config with new token values"""
281+
dpath.new(self._connector_config, self._access_token_config_path, access_token)
282+
if refresh_token:
283+
dpath.new(self._connector_config, self._refresh_token_config_path, refresh_token)
284+
self._refresh_token = refresh_token
285+
286+
if expires_in:
287+
expiry_date = self._parse_token_expiration_date(expires_in)
288+
dpath.new(
289+
self._connector_config,
290+
self._token_expiry_date_config_path,
291+
str(expiry_date)
292+
)
293+
294+
self._emit_control_message()
295+
296+
def _emit_control_message(self) -> None:
297+
"""Emit control message for config update"""
298+
if not isinstance(self.__message_repository, NoopMessageRepository):
299+
self.__message_repository.emit_message(
300+
create_connector_config_control_message(self._connector_config)
301+
)
302+
else:
303+
emit_configuration_as_airbyte_control_message(self._connector_config)
304+
305+
def get_token_expiry_date(self) -> AirbyteDateTime:
306+
"""Get token expiry date from config"""
307+
expiry_date = self._get_config_value_by_path(self._token_expiry_date_config_path)
308+
if expiry_date == "":
309+
return ab_datetime_now() - timedelta(days=1)
310+
return ab_datetime_parse(str(expiry_date))
311+
312+
@property
313+
def access_token(self) -> str:
314+
"""Get access token from config"""
315+
return self._get_config_value_by_path(self._access_token_config_path)
316+
317+
def _get_config_value_by_path(
318+
self,
319+
config_path: Union[str, Sequence[str]],
320+
default: Optional[str] = None
321+
) -> str:
322+
"""Get a value from the config using a path"""
323+
return dpath.get(
324+
self._connector_config,
325+
config_path,
326+
default=default if default is not None else "",
327+
)
328+
329+
@property
330+
def _message_repository(self) -> MessageRepository:
331+
return self.__message_repository

0 commit comments

Comments
 (0)