Skip to content

Commit d060a8b

Browse files
committed
merge from main
2 parents b5b8197 + 24cbc51 commit d060a8b

File tree

6 files changed

+433
-197
lines changed

6 files changed

+433
-197
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
239239
def _has_access_token_been_initialized(self) -> bool:
240240
return self._access_token is not None
241241

242-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
243-
self._token_expiry_date = self._parse_token_expiration_date(value)
242+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
243+
self._token_expiry_date = value
244244

245245
def get_assertion_name(self) -> str:
246246
return self.assertion_name

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
130130
headers = self.get_refresh_request_headers()
131131
return headers if headers else None
132132

133-
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
133+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134134
"""
135135
Returns the refresh token and its expiration datetime
136136
@@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
148148
# PRIVATE METHODS
149149
# ----------------
150150

151+
def _default_token_expiry_date(self) -> AirbyteDateTime:
152+
"""
153+
Returns the default token expiry date
154+
"""
155+
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156+
default_token_expiry_duration_hours = 1 # 1 hour
157+
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158+
151159
def _wrap_refresh_token_exception(
152160
self, exception: requests.exceptions.RequestException
153161
) -> bool:
@@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) ->
257265

258266
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
259267
"""
260-
Return the expiration datetime of the refresh token
268+
Parse a string or integer token expiration date into a datetime object
261269
262270
:return: expiration datetime
263271
"""
264-
if not value and not self.token_has_expired():
265-
# No expiry token was provided but the previous one is not expired so it's fine
266-
return self.get_token_expiry_date()
267-
268272
if self.token_expiry_is_time_of_expiration:
269273
if not self.token_expiry_date_format:
270274
raise ValueError(
@@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
308312
"""
309313
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
310314

311-
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
315+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
312316
"""
313317
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
314318
319+
If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
320+
315321
Args:
316322
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
317323
318324
Returns:
319-
str: The extracted token_expiry_date.
325+
The extracted token_expiry_date or None if not found.
320326
"""
321-
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
327+
expires_in = self._find_and_get_value_from_response(
328+
response_data, self.get_expires_in_name()
329+
)
330+
if expires_in is not None:
331+
return self._parse_token_expiration_date(expires_in)
332+
333+
# expires_in is None
334+
existing_expiry_date = self.get_token_expiry_date()
335+
if existing_expiry_date and not self.token_has_expired():
336+
return existing_expiry_date
337+
338+
return self._default_token_expiry_date()
322339

323340
def _find_and_get_value_from_response(
324341
self,
@@ -344,7 +361,7 @@ def _find_and_get_value_from_response(
344361
"""
345362
if current_depth > max_depth:
346363
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
347-
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
364+
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
348365
raise ResponseKeysMaxRecurtionReached(
349366
internal_message=message, message=message, failure_type=FailureType.config_error
350367
)
@@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
441458
"""Expiration date of the access token"""
442459

443460
@abstractmethod
444-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
461+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
445462
"""Setter for access token expiration date"""
446463

447464
@abstractmethod

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def get_grant_type(self) -> str:
120120
def get_token_expiry_date(self) -> AirbyteDateTime:
121121
return self._token_expiry_date
122122

123-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
124-
self._token_expiry_date = self._parse_token_expiration_date(value)
123+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
124+
self._token_expiry_date = value
125125

126126
@property
127127
def token_expiry_is_time_of_expiration(self) -> bool:
@@ -316,26 +316,6 @@ def token_has_expired(self) -> bool:
316316
"""Returns True if the token is expired"""
317317
return ab_datetime_now() > self.get_token_expiry_date()
318318

319-
@staticmethod
320-
def get_new_token_expiry_date(
321-
access_token_expires_in: str,
322-
token_expiry_date_format: str | None = None,
323-
) -> AirbyteDateTime:
324-
"""
325-
Calculate the new token expiry date based on the provided expiration duration or format.
326-
327-
Args:
328-
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
329-
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.
330-
331-
Returns:
332-
AirbyteDateTime: The calculated expiry date of the access token.
333-
"""
334-
if token_expiry_date_format:
335-
return ab_datetime_parse(access_token_expires_in)
336-
else:
337-
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))
338-
339319
def get_access_token(self) -> str:
340320
"""Retrieve new access and refresh token if the access token has expired.
341321
The new refresh token is persisted with the set_refresh_token function
@@ -346,16 +326,13 @@ def get_access_token(self) -> str:
346326
new_access_token, access_token_expires_in, new_refresh_token = (
347327
self.refresh_access_token()
348328
)
349-
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
350-
access_token_expires_in, self._token_expiry_date_format
351-
)
352329
self.access_token = new_access_token
353330
self.set_refresh_token(new_refresh_token)
354-
self.set_token_expiry_date(new_token_expiry_date)
331+
self.set_token_expiry_date(access_token_expires_in)
355332
self._emit_control_message()
356333
return self.access_token
357334

358-
def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
335+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
359336
"""
360337
Refreshes the access token by making a handled request and extracting the necessary token information.
361338

0 commit comments

Comments
 (0)