1+ #
2+ # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
3+ #
4+
5+ from datetime import timedelta
6+ from typing import Any , List , Mapping , Optional , Sequence , Tuple , Union
7+ import requests
8+ from requests .auth import AuthBase
9+ import dpath
10+ import json
11+ from json import JSONDecodeError
12+ from abc import abstractmethod
13+
14+ from airbyte_cdk .config_observation import (
15+ create_connector_config_control_message ,
16+ emit_configuration_as_airbyte_control_message ,
17+ )
18+ from airbyte_cdk .sources .message import MessageRepository , NoopMessageRepository
19+ from airbyte_cdk .utils .datetime_helpers import (
20+ AirbyteDateTime ,
21+ ab_datetime_now ,
22+ ab_datetime_parse ,
23+ )
24+ from airbyte_cdk .models import FailureType , Level
25+ from airbyte_cdk .sources .streams .http .exceptions import DefaultBackoffException
26+ from airbyte_cdk .utils import AirbyteTracedException
27+ from airbyte_cdk .sources .http_logger import format_http_message
28+
class AbstractOauth2Authenticator(AuthBase):
    """Base class for OAuth2 authentication with shared token refresh logic.

    Subclasses decide where the access token and its expiry date live by
    implementing ``access_token``, ``refresh_access_token`` and
    ``get_token_expiry_date``. This class supplies request signing, the
    refresh HTTP call, refresh-body assembly and expiry parsing.
    """

    def __init__(
        self,
        token_refresh_endpoint: str,
        client_id: str,
        client_secret: str,
        refresh_token: str,
        client_id_name: str = "client_id",
        client_secret_name: str = "client_secret",
        refresh_token_name: str = "refresh_token",
        scopes: Optional[List[str]] = None,
        access_token_name: str = "access_token",
        expires_in_name: str = "expires_in",
        refresh_request_body: Optional[Mapping[str, Any]] = None,
        refresh_request_headers: Optional[Mapping[str, Any]] = None,
        grant_type_name: str = "grant_type",
        grant_type: str = "refresh_token",
        token_expiry_date: Optional[AirbyteDateTime] = None,
        token_expiry_date_format: Optional[str] = None,
        token_expiry_is_time_of_expiration: bool = False,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ):
        """Store the settings used to build and interpret refresh requests.

        Args:
            token_refresh_endpoint: URL the refresh request is POSTed to.
            client_id: Value sent under ``client_id_name`` in the refresh body.
            client_secret: Value sent under ``client_secret_name``.
            refresh_token: Value sent under ``refresh_token_name``.
            client_id_name: Body key for the client id.
            client_secret_name: Body key for the client secret.
            refresh_token_name: Body key for the refresh token.
            scopes: When non-empty, sent under the ``"scopes"`` body key.
            access_token_name: Response key holding the access token.
            expires_in_name: Response key holding the expiry information.
            refresh_request_body: Extra body fields; they never override the
                standard OAuth fields built by this class.
            refresh_request_headers: Headers sent with the refresh request.
            grant_type_name: Body key for the grant type.
            grant_type: Grant type value sent in the refresh body.
            token_expiry_date: Initial expiry date stored on the instance.
            token_expiry_date_format: Must be set when
                ``token_expiry_is_time_of_expiration`` is true.
            token_expiry_is_time_of_expiration: When true the expiry value is
                parsed as a point in time; otherwise as a number of seconds
                from now.
            refresh_token_error_status_codes: Status codes whose response body
                is inspected for an invalid-refresh-token marker.
            refresh_token_error_key: Key looked up in that response body.
            refresh_token_error_values: Values under that key that mean the
                refresh token is invalid or expired.
        """
        self._token_refresh_endpoint = token_refresh_endpoint
        self._client_id = client_id
        self._client_secret = client_secret
        self._refresh_token = refresh_token
        self._client_id_name = client_id_name
        self._client_secret_name = client_secret_name
        self._refresh_token_name = refresh_token_name
        self._scopes = scopes
        self._access_token_name = access_token_name
        self._expires_in_name = expires_in_name
        self._refresh_request_body = refresh_request_body
        self._refresh_request_headers = refresh_request_headers
        self._grant_type_name = grant_type_name
        self._grant_type = grant_type
        self._token_expiry_date = token_expiry_date
        self._token_expiry_date_format = token_expiry_date_format
        self._token_expiry_is_time_of_expiration = token_expiry_is_time_of_expiration
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the authorization header to an outgoing request."""
        request.headers.update(self.get_auth_header())
        return request

    def get_auth_header(self) -> Mapping[str, Any]:
        """Return the ``Authorization`` header carrying the current bearer token."""
        # No whitespace after the token: the header value must be exactly
        # "Bearer <token>".
        return {"Authorization": f"Bearer {self.get_access_token()}"}

    def get_access_token(self) -> str:
        """Get the current access token, refreshing if expired"""
        if self.token_has_expired():
            self.refresh_access_token()
        return self.access_token

    def token_has_expired(self) -> bool:
        """Check if the current token has expired.

        A missing expiry date counts as expired.
        """
        # Read the expiry once so both checks see the same value and
        # config-backed subclasses do not parse the date twice.
        expiry_date = self.get_token_expiry_date()
        return not expiry_date or ab_datetime_now() > expiry_date

    def _make_refresh_request(self) -> Mapping[str, Any]:
        """
        Make the HTTP request to refresh OAuth tokens.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            requests.exceptions.RequestException: Any other request failure is re-raised
                unchanged.
        """
        try:
            # NOTE(review): no timeout is passed, so a stalled endpoint blocks
            # this call indefinitely — confirm whether that is intended.
            response = requests.post(
                url=self._token_refresh_endpoint,
                data=self._get_refresh_request_body(),
                headers=self._refresh_request_headers,
            )
            self._log_response(response)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            error_response = e.response
            if error_response is not None:
                status_code = error_response.status_code
                if status_code == 429 or status_code >= 500:
                    raise DefaultBackoffException(
                        request=error_response.request, response=error_response
                    ) from e
                if (
                    status_code in self._refresh_token_error_status_codes
                    and self._is_invalid_refresh_token_response(error_response)
                ):
                    message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                    raise AirbyteTracedException(
                        message=message,
                        internal_message=message,
                        failure_type=FailureType.config_error,
                    ) from e
            raise

    def _is_invalid_refresh_token_response(self, response: requests.Response) -> bool:
        """Tell whether an error response carries a configured invalid-token marker."""
        try:
            payload = response.json()
        except ValueError:
            # Body is not JSON (every JSON decode error is a ValueError).
            return False
        if not isinstance(payload, dict):
            # A JSON list/scalar has no keys to inspect; let the caller
            # re-raise the original HTTP error instead of failing here.
            return False
        return payload.get(self._refresh_token_error_key) in self._refresh_token_error_values

    def _get_refresh_request_body(self) -> Mapping[str, Any]:
        """Get the request body for token refresh"""
        body = {
            self._grant_type_name: self._grant_type,
            self._client_id_name: self._client_id,
            self._client_secret_name: self._client_secret,
            self._refresh_token_name: self._refresh_token,
        }
        if self._scopes:
            body["scopes"] = self._scopes
        # We defer to existing oauth constructs over custom configured fields,
        # so custom entries only fill keys that are not already present.
        for key, val in (self._refresh_request_body or {}).items():
            body.setdefault(key, val)
        return body

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """Parse token expiration field from response into an AirbyteDateTime object.

        Args:
            value: Either a point in time (when
                ``token_expiry_is_time_of_expiration`` is set) or a number of
                seconds until expiry.

        Raises:
            ValueError: If a point-in-time value is expected but no
                ``token_expiry_date_format`` was configured, or if a seconds
                value cannot be converted to a number.
        """
        if self._token_expiry_is_time_of_expiration:
            if not self._token_expiry_date_format:
                raise ValueError("Token expiry date format required when using expiration time")
            # The format is only checked for presence here; parsing itself is
            # delegated to ab_datetime_parse.
            return ab_datetime_parse(str(value))

        try:
            # Going through float accepts values such as "3600.0".
            seconds = int(float(str(value)))
        except (ValueError, TypeError) as err:
            raise ValueError(f"Invalid expires_in value: {value}") from err
        return ab_datetime_now() + timedelta(seconds=seconds)

    @property
    def _message_repository(self) -> MessageRepository:
        """Override in subclasses that need message logging"""
        return NoopMessageRepository()

    def _log_response(self, response: requests.Response) -> None:
        """Log the response if a message repository is configured"""
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                # Passed as a callable so the message is only formatted if the
                # repository decides to evaluate it.
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    None,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Get current access token - implemented by subclasses"""
        pass

    @abstractmethod
    def refresh_access_token(self) -> None:
        """Refresh the access token - implemented by subclasses"""
        pass

    @abstractmethod
    def get_token_expiry_date(self) -> Optional[AirbyteDateTime]:
        """Get token expiry date - implemented by subclasses"""
        pass
198+
class Oauth2Authenticator(AbstractOauth2Authenticator):
    """OAuth2 authenticator that keeps the token and its expiry on the instance."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base class and start with no token.

        Note: the expiry date is reset to ``None`` after the base constructor
        runs, so a ``token_expiry_date`` argument does not survive
        construction.
        """
        super().__init__(*args, **kwargs)
        self._access_token: Optional[str] = None
        self._token_expiry_date: Optional[AirbyteDateTime] = None

    def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
        """Fetch a new access token, remember it, and return it with its expiry value.

        The stored expiry date is only replaced when the response carries a
        truthy expiry value; otherwise the previous one is kept.
        """
        payload = self._make_refresh_request()
        new_token = payload[self._access_token_name]
        expiry_value = payload.get(self._expires_in_name)

        self._access_token = new_token
        if expiry_value:
            self._token_expiry_date = self._parse_token_expiration_date(expiry_value)

        return new_token, expiry_value

    def get_token_expiry_date(self) -> Optional[AirbyteDateTime]:
        """Return the expiry date recorded by the last refresh, if any."""
        return self._token_expiry_date

    @property
    def access_token(self) -> str:
        """Return the stored access token; raise ``ValueError`` if none is set."""
        token = self._access_token
        if token:
            return token
        raise ValueError("Access token not set")
227+
class SingleUseRefreshTokenOauth2Authenticator(AbstractOauth2Authenticator):
    """OAuth2 authenticator that stores tokens in config and emits updates.

    The access token, refresh token and expiry date are read from and written
    back to ``connector_config`` at configurable paths. After each refresh the
    updated config is emitted as a control message.
    """

    def __init__(
        self,
        connector_config: Mapping[str, Any],
        token_refresh_endpoint: str,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        access_token_config_path: Sequence[str] = ("credentials", "access_token"),
        refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
        token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"),
        message_repository: MessageRepository = NoopMessageRepository(),
        **kwargs
    ):
        """Bind the authenticator to a connector config.

        Args:
            connector_config: Config holding the credentials; it is updated in
                place on every refresh.
            token_refresh_endpoint: URL the refresh request is POSTed to.
            client_id: Client id; when falsy it is read from
                ``credentials.client_id`` in the config.
            client_secret: Client secret; when falsy it is read from
                ``credentials.client_secret`` in the config.
            access_token_config_path: Config path of the access token.
            refresh_token_config_path: Config path of the refresh token.
            token_expiry_date_config_path: Config path of the expiry date.
            message_repository: Repository used for logging and, unless it is
                a ``NoopMessageRepository``, for emitting config updates.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # These must be set before the config lookups below, which rely on
        # self._connector_config.
        self._connector_config = connector_config
        self._access_token_config_path = access_token_config_path
        self._refresh_token_config_path = refresh_token_config_path
        self._token_expiry_date_config_path = token_expiry_date_config_path
        self.__message_repository = message_repository

        # Get credentials from config if not provided
        if not client_id:
            client_id = self._get_config_value_by_path(("credentials", "client_id"))
        if not client_secret:
            client_secret = self._get_config_value_by_path(("credentials", "client_secret"))

        super().__init__(
            token_refresh_endpoint=token_refresh_endpoint,
            client_id=client_id,
            client_secret=client_secret,
            refresh_token=self._get_config_value_by_path(refresh_token_config_path),
            **kwargs
        )

    def refresh_access_token(self) -> Tuple[str, Union[str, int], str]:
        """Refresh access token and update config.

        Returns:
            The access token, the raw expiry value and the refresh token taken
            from the refresh response (the latter two may be ``None`` when the
            response omits them).
        """
        response = self._make_refresh_request()

        access_token = response[self._access_token_name]
        refresh_token = response.get(self._refresh_token_name)
        expires_in = response.get(self._expires_in_name)

        self._update_config(access_token, refresh_token, expires_in)
        return access_token, expires_in, refresh_token

    def _update_config(
        self,
        access_token: str,
        refresh_token: Optional[str],
        expires_in: Optional[Union[str, int]],
    ) -> None:
        """Update the config with new token values and emit the change.

        The refresh token and expiry date are only overwritten when the
        response supplied a truthy value for them.
        """
        dpath.new(self._connector_config, self._access_token_config_path, access_token)
        if refresh_token:
            dpath.new(self._connector_config, self._refresh_token_config_path, refresh_token)
            # Keep the in-memory copy in sync so the next refresh request
            # sends the new refresh token.
            self._refresh_token = refresh_token

        if expires_in:
            expiry_date = self._parse_token_expiration_date(expires_in)
            dpath.new(
                self._connector_config,
                self._token_expiry_date_config_path,
                str(expiry_date),
            )

        self._emit_control_message()

    def _emit_control_message(self) -> None:
        """Emit control message for config update.

        A non-noop message repository receives the control message; otherwise
        the config is handed to ``emit_configuration_as_airbyte_control_message``.
        """
        if not isinstance(self.__message_repository, NoopMessageRepository):
            self.__message_repository.emit_message(
                create_connector_config_control_message(self._connector_config)
            )
        else:
            emit_configuration_as_airbyte_control_message(self._connector_config)

    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Get token expiry date from config.

        Returns a date one day in the past when the config holds no usable
        value, so the token is treated as expired and gets refreshed.
        """
        expiry_date = self._get_config_value_by_path(self._token_expiry_date_config_path)
        # Any falsy value (missing -> "", explicit None, 0) means "no expiry
        # recorded"; parsing str(None) would raise instead of refreshing.
        if not expiry_date:
            return ab_datetime_now() - timedelta(days=1)
        return ab_datetime_parse(str(expiry_date))

    @property
    def access_token(self) -> str:
        """Get access token from config"""
        return self._get_config_value_by_path(self._access_token_config_path)

    def _get_config_value_by_path(
        self,
        config_path: Union[str, Sequence[str]],
        default: Optional[str] = None,
    ) -> str:
        """Get a value from the config using a path.

        Returns ``default`` when the path is absent, or ``""`` when no default
        was given.
        """
        return dpath.get(
            self._connector_config,
            config_path,
            default=default if default is not None else "",
        )

    @property
    def _message_repository(self) -> MessageRepository:
        """Return the message repository supplied at construction."""
        return self.__message_repository
0 commit comments