2525_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository ()
2626
2727
28+ class ResponseKeysMaxRecurtionReached (AirbyteTracedException ):
29+ """
30+ Raised when the max level of recursion is reached, when trying to
31+ find-and-get the target key, during the `_make_handled_request`
32+ """
33+
34+
2835class AbstractOauth2Authenticator (AuthBase ):
2936 """
3037 Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
@@ -53,15 +60,31 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
5360 request .headers .update (self .get_auth_header ())
5461 return request
5562
63+ @property
64+ def _is_access_token_flow (self ) -> bool :
65+ return self .get_token_refresh_endpoint () is None and self .access_token is not None
66+
67+ @property
68+ def token_expiry_is_time_of_expiration (self ) -> bool :
69+ """
70+ Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
71+ """
72+
73+ return False
74+
75+ @property
76+ def token_expiry_date_format (self ) -> Optional [str ]:
77+ """
78+ Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
79+ """
80+
81+ return None
82+
5683 def get_auth_header (self ) -> Mapping [str , Any ]:
5784 """HTTP header to set on the requests"""
5885 token = self .access_token if self ._is_access_token_flow else self .get_access_token ()
5986 return {"Authorization" : f"Bearer { token } " }
6087
61- @property
62- def _is_access_token_flow (self ) -> bool :
63- return self .get_token_refresh_endpoint () is None and self .access_token is not None
64-
6588 def get_access_token (self ) -> str :
6689 """Returns the access token"""
6790 if self .token_has_expired ():
@@ -107,9 +130,39 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
107130 headers = self .get_refresh_request_headers ()
108131 return headers if headers else None
109132
133+ def refresh_access_token (self ) -> Tuple [str , Union [str , int ]]:
134+ """
135+ Returns the refresh token and its expiration datetime
136+
137+ :return: a tuple of (access_token, token_lifespan)
138+ """
139+ response_json = self ._make_handled_request ()
140+ self ._ensure_access_token_in_response (response_json )
141+
142+ return (
143+ self ._extract_access_token (response_json ),
144+ self ._extract_token_expiry_date (response_json ),
145+ )
146+
147+ # ----------------
148+ # PRIVATE METHODS
149+ # ----------------
150+
110151 def _wrap_refresh_token_exception (
111152 self , exception : requests .exceptions .RequestException
112153 ) -> bool :
154+ """
155+ Wraps and handles exceptions that occur during the refresh token process.
156+
157+ This method checks if the provided exception is related to a refresh token error
158+ by examining the response status code and specific error content.
159+
160+ Args:
161+ exception (requests.exceptions.RequestException): The exception raised during the request.
162+
163+ Returns:
164+ bool: True if the exception is related to a refresh token error, False otherwise.
165+ """
113166 try :
114167 if exception .response is not None :
115168 exception_content = exception .response .json ()
@@ -131,30 +184,35 @@ def _wrap_refresh_token_exception(
131184 ),
132185 max_time = 300 ,
133186 )
134- def _get_refresh_access_token_response (self ) -> Any :
187+ def _make_handled_request (self ) -> Any :
188+ """
189+ Makes a handled HTTP request to refresh an OAuth token.
190+
191+ This method sends a POST request to the token refresh endpoint with the necessary
192+ headers and body to obtain a new access token. It handles various exceptions that
193+ may occur during the request and logs the response for troubleshooting purposes.
194+
195+ Returns:
196+ Mapping[str, Any]: The JSON response from the token refresh endpoint.
197+
198+ Raises:
199+ DefaultBackoffException: If the response status code is 429 (Too Many Requests)
200+ or any 5xx server error.
201+ AirbyteTracedException: If the refresh token is invalid or expired, prompting
202+ re-authentication.
203+ Exception: For any other exceptions that occur during the request.
204+ """
135205 try :
136206 response = requests .request (
137207 method = "POST" ,
138208 url = self .get_token_refresh_endpoint (), # type: ignore # returns None, if not provided, but str | bytes is expected.
139209 data = self .build_refresh_request_body (),
140210 headers = self .build_refresh_request_headers (),
141211 )
142- if response .ok :
143- response_json = response .json ()
144- # Add the access token to the list of secrets so it is replaced before logging the response
145- # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
146- access_key = response_json .get (self .get_access_token_name ())
147- if not access_key :
148- raise Exception (
149- "Token refresh API response was missing access token {self.get_access_token_name()}"
150- )
151- add_to_secrets (access_key )
152- self ._log_response (response )
153- return response_json
154- else :
155- # log the response even if the request failed for troubleshooting purposes
156- self ._log_response (response )
157- response .raise_for_status ()
212+ # log the response even if the request failed for troubleshooting purposes
213+ self ._log_response (response )
214+ response .raise_for_status ()
215+ return response .json ()
158216 except requests .exceptions .RequestException as e :
159217 if e .response is not None :
160218 if e .response .status_code == 429 or e .response .status_code >= 500 :
@@ -168,17 +226,34 @@ def _get_refresh_access_token_response(self) -> Any:
168226 except Exception as e :
169227 raise Exception (f"Error while refreshing access token: { e } " ) from e
170228
171- def refresh_access_token (self ) -> Tuple [str , Union [ str , int ]] :
229+ def _ensure_access_token_in_response (self , response_data : Mapping [str , Any ]) -> None :
172230 """
173- Returns the refresh token and its expiration datetime
231+ Ensures that the access token is present in the response data.
174232
175- :return: a tuple of (access_token, token_lifespan)
176- """
177- response_json = self ._get_refresh_access_token_response ()
233+ This method attempts to extract the access token from the provided response data.
234+ If the access token is not found, it raises an exception indicating that the token
235+ refresh API response was missing the access token. If the access token is found,
236+ it adds the token to the list of secrets to ensure it is replaced before logging
237+ the response.
238+
239+ Args:
240+ response_data (Mapping[str, Any]): The response data from which to extract the access token.
178241
179- return response_json [self .get_access_token_name ()], response_json [
180- self .get_expires_in_name ()
181- ]
242+ Raises:
243+ Exception: If the access token is not found in the response data.
244+ ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
245+ """
246+ try :
247+ access_key = self ._extract_access_token (response_data )
248+ if not access_key :
249+ raise Exception (
250+ "Token refresh API response was missing access token {self.get_access_token_name()}"
251+ )
252+ # Add the access token to the list of secrets so it is replaced before logging the response
253+ # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
254+ add_to_secrets (access_key )
255+ except ResponseKeysMaxRecurtionReached as e :
256+ raise e
182257
183258 def _parse_token_expiration_date (self , value : Union [str , int ]) -> AirbyteDateTime :
184259 """
@@ -206,22 +281,125 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim
206281 f"Invalid expires_in value: { value } . Expected number of seconds when no format specified."
207282 )
208283
209- @property
210- def token_expiry_is_time_of_expiration (self ) -> bool :
284+ def _extract_access_token (self , response_data : Mapping [str , Any ]) -> Any :
211285 """
212- Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
286+ Extracts the access token from the given response data.
287+
288+ Args:
289+ response_data (Mapping[str, Any]): The response data from which to extract the access token.
290+
291+ Returns:
292+ str: The extracted access token.
213293 """
294+ return self ._find_and_get_value_from_response (response_data , self .get_access_token_name ())
214295
215- return False
296+ def _extract_refresh_token (self , response_data : Mapping [str , Any ]) -> Any :
297+ """
298+ Extracts the refresh token from the given response data.
216299
217- @property
218- def token_expiry_date_format (self ) -> Optional [str ]:
300+ Args:
301+ response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
302+
303+ Returns:
304+ str: The extracted refresh token.
219305 """
220- Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
306+ return self ._find_and_get_value_from_response (response_data , self .get_refresh_token_name ())
307+
308+ def _extract_token_expiry_date (self , response_data : Mapping [str , Any ]) -> Any :
309+ """
310+ Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
311+
312+ Args:
313+ response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
314+
315+ Returns:
316+ str: The extracted token_expiry_date.
221317 """
318+ return self ._find_and_get_value_from_response (response_data , self .get_expires_in_name ())
319+
320+ def _find_and_get_value_from_response (
321+ self ,
322+ response_data : Mapping [str , Any ],
323+ key_name : str ,
324+ max_depth : int = 5 ,
325+ current_depth : int = 0 ,
326+ ) -> Any :
327+ """
328+ Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
329+
330+ Args:
331+ response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
332+ key_name (str): The key to search for in the response data.
333+ max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
334+ current_depth (int, optional): The current depth of the recursion. Defaults to 0.
335+
336+ Returns:
337+ Any: The value associated with the specified key if found, otherwise None.
338+
339+ Raises:
340+ AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
341+ """
342+ if current_depth > max_depth :
343+ # this is needed to avoid an inf loop, possible with a very deep nesting observed.
344+ message = f"The maximum level of recursion is reached. Couldn't find the speficied `{ key_name } ` in the response."
345+ raise ResponseKeysMaxRecurtionReached (
346+ internal_message = message , message = message , failure_type = FailureType .config_error
347+ )
348+
349+ if isinstance (response_data , dict ):
350+ # get from the root level
351+ if key_name in response_data :
352+ return response_data [key_name ]
353+
354+ # get from the nested object
355+ for _ , value in response_data .items ():
356+ result = self ._find_and_get_value_from_response (
357+ value , key_name , max_depth , current_depth + 1
358+ )
359+ if result is not None :
360+ return result
361+
362+ # get from the nested array object
363+ elif isinstance (response_data , list ):
364+ for item in response_data :
365+ result = self ._find_and_get_value_from_response (
366+ item , key_name , max_depth , current_depth + 1
367+ )
368+ if result is not None :
369+ return result
222370
223371 return None
224372
373+ @property
374+ def _message_repository (self ) -> Optional [MessageRepository ]:
375+ """
376+ The implementation can define a message_repository if it wants debugging logs for HTTP requests
377+ """
378+ return _NOOP_MESSAGE_REPOSITORY
379+
380+ def _log_response (self , response : requests .Response ) -> None :
381+ """
382+ Logs the HTTP response using the message repository if it is available.
383+
384+ Args:
385+ response (requests.Response): The HTTP response to log.
386+ """
387+ if self ._message_repository :
388+ self ._message_repository .log_message (
389+ Level .DEBUG ,
390+ lambda : format_http_message (
391+ response ,
392+ "Refresh token" ,
393+ "Obtains access token" ,
394+ self ._NO_STREAM_NAME ,
395+ is_auxiliary = True ,
396+ ),
397+ )
398+
399+ # ----------------
400+ # ABSTR METHODS
401+ # ----------------
402+
225403 @abstractmethod
226404 def get_token_refresh_endpoint (self ) -> Optional [str ]:
227405 """Returns the endpoint to refresh the access token"""
@@ -295,23 +473,3 @@ def access_token(self) -> str:
295473 @abstractmethod
296474 def access_token (self , value : str ) -> str :
297475 """Setter for the access token"""
298-
299- @property
300- def _message_repository (self ) -> Optional [MessageRepository ]:
301- """
302- The implementation can define a message_repository if it wants debugging logs for HTTP requests
303- """
304- return _NOOP_MESSAGE_REPOSITORY
305-
306- def _log_response (self , response : requests .Response ) -> None :
307- if self ._message_repository :
308- self ._message_repository .log_message (
309- Level .DEBUG ,
310- lambda : format_http_message (
311- response ,
312- "Refresh token" ,
313- "Obtains access token" ,
314- self ._NO_STREAM_NAME ,
315- is_auxiliary = True ,
316- ),
317- )
0 commit comments