Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def _get_refresh_access_token_response(self) -> Any:
response_json = response.json()
# Add the access token to the list of secrets so it is replaced before logging the response
# An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
access_key = response_json.get(self.get_access_token_name())
access_key = self._find_and_get_value_from_response(
response_json,
self.get_access_token_name(),
)
if not access_key:
raise Exception(
"Token refresh API response was missing access token {self.get_access_token_name()}"
Expand All @@ -164,6 +167,8 @@ def _get_refresh_access_token_response(self) -> Any:
internal_message=message, message=message, failure_type=FailureType.config_error
)
raise
except AirbyteTracedException as e:
raise e
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e

Expand All @@ -175,9 +180,10 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
"""
response_json = self._get_refresh_access_token_response()

return response_json[self.get_access_token_name()], response_json[
self.get_expires_in_name()
]
return (
self._find_and_get_value_from_response(response_json, self.get_access_token_name()),
self._find_and_get_value_from_response(response_json, self.get_expires_in_name()),
)

def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateTime:
"""
Expand Down Expand Up @@ -292,6 +298,59 @@ def _message_repository(self) -> Optional[MessageRepository]:
"""
return _NOOP_MESSAGE_REPOSITORY

def _find_and_get_value_from_response(
self,
response_data: Mapping[str, Any],
key_name: str,
max_depth: int = 5,
current_depth: int = 0,
) -> Any:
"""
Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

Args:
response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
key_name (str): The key to search for in the response data.
max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
current_depth (int, optional): The current depth of the recursion. Defaults to 0.

Returns:
Any: The value associated with the specified key if found, otherwise None.

Raises:
AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
"""
if current_depth > max_depth:
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
raise AirbyteTracedException(
internal_message=message, message=message, failure_type=FailureType.config_error
)

if isinstance(response_data, dict):
# get from the root level
if key_name in response_data:
return response_data[key_name]

# get from the nested object
for _, value in response_data.items():
result = self._find_and_get_value_from_response(
value, key_name, max_depth, current_depth + 1
)
if result is not None:
return result

# get from the nested array object
elif isinstance(response_data, list):
for item in response_data:
result = self._find_and_get_value_from_response(
item, key_name, max_depth, current_depth + 1
)
if result is not None:
return result

return None

def _log_response(self, response: requests.Response) -> None:
if self._message_repository:
self._message_repository.log_message(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ def refresh_access_token( # type: ignore[override] # Signature doesn't match b
) -> Tuple[str, str, str]:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
self._find_and_get_value_from_response(response_json, self.get_access_token_name()),
self._find_and_get_value_from_response(response_json, self.get_expires_in_name()),
self._find_and_get_value_from_response(response_json, self.get_refresh_token_name()),
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def test_refresh_access_token(self, mocker):
assert isinstance(expires_in, int)
assert ("access_token", 1000) == (token, expires_in)

# Test with expires_in as str
# Test with expires_in as str(int)
mocker.patch.object(
resp, "json", return_value={"access_token": "access_token", "expires_in": "2000"}
)
Expand All @@ -266,7 +266,7 @@ def test_refresh_access_token(self, mocker):
assert isinstance(expires_in, str)
assert ("access_token", "2000") == (token, expires_in)

# Test with expires_in as str
# Test with expires_in as datetime(str)
mocker.patch.object(
resp,
"json",
Expand All @@ -277,6 +277,78 @@ def test_refresh_access_token(self, mocker):
assert isinstance(expires_in, str)
assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in)

# Test with nested access_token and expires_in as str(int)
mocker.patch.object(
resp,
"json",
return_value={"data": {"access_token": "access_token_nested", "expires_in": "2001"}},
)
token, expires_in = oauth.refresh_access_token()

assert isinstance(expires_in, str)
assert ("access_token_nested", "2001") == (token, expires_in)

# Test with multiple nested levels access_token and expires_in as str(int)
mocker.patch.object(
resp,
"json",
return_value={
"data": {
"scopes": ["one", "two", "three"],
"data2": {
"not_access_token": "test_non_access_token_value",
"data3": {
"some_field": "test_value",
"expires_at": "2800",
"data4": {
"data5": {
"access_token": "access_token_deeply_nested",
"expires_in": "2002",
}
},
},
},
}
},
)
token, expires_in = oauth.refresh_access_token()

assert isinstance(expires_in, str)
assert ("access_token_deeply_nested", "2002") == (token, expires_in)

# Test with max nested levels access_token and expires_in as str(int)
mocker.patch.object(
resp,
"json",
return_value={
"data": {
"scopes": ["one", "two", "three"],
"data2": {
"not_access_token": "test_non_access_token_value",
"data3": {
"some_field": "test_value",
"expires_at": "2800",
"data4": {
"data5": {
# this is the edge case, but worth testing.
"data6": {
"access_token": "access_token_super_deeply_nested",
"expires_in": "2003",
}
}
},
},
},
}
},
)
with pytest.raises(AirbyteTracedException) as exc_info:
oauth.refresh_access_token()
error_message = "The maximum level of recursion is reached. Couldn't find the speficied `access_token` in the response."
assert exc_info.value.internal_message == error_message
assert exc_info.value.message == error_message
assert exc_info.value.failure_type == FailureType.config_error

def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
Expand Down Expand Up @@ -590,6 +662,11 @@ def test_given_message_repository_when_get_access_token_then_log_request(
"airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.format_http_message",
return_value="formatted json",
)
# patching the `expires_in`
mocker.patch(
"airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.AbstractOauth2Authenticator._find_and_get_value_from_response",
return_value="7200",
)
authenticator.token_has_expired = mocker.Mock(return_value=True)

authenticator.get_access_token()
Expand Down
Loading