33#
44
55from dataclasses import InitVar , dataclass , field
6- from typing import Any , List , Mapping , Optional , Union
6+ from typing import Any , List , Mapping , MutableMapping , Optional , Union
77
88import pendulum
99
1010from airbyte_cdk .sources .declarative .auth .declarative_authenticator import DeclarativeAuthenticator
11+ from airbyte_cdk .sources .declarative .interpolation .interpolated_boolean import InterpolatedBoolean
1112from airbyte_cdk .sources .declarative .interpolation .interpolated_mapping import InterpolatedMapping
1213from airbyte_cdk .sources .declarative .interpolation .interpolated_string import InterpolatedString
1314from airbyte_cdk .sources .message import MessageRepository , NoopMessageRepository
@@ -44,10 +45,10 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
4445 message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
4546 """
4647
47- client_id : Union [InterpolatedString , str ]
48- client_secret : Union [InterpolatedString , str ]
4948 config : Mapping [str , Any ]
5049 parameters : InitVar [Mapping [str , Any ]]
50+ client_id : Optional [Union [InterpolatedString , str ]] = None
51+ client_secret : Optional [Union [InterpolatedString , str ]] = None
5152 token_refresh_endpoint : Optional [Union [InterpolatedString , str ]] = None
5253 refresh_token : Optional [Union [InterpolatedString , str ]] = None
5354 scopes : Optional [List [str ]] = None
@@ -66,6 +67,8 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
6667 grant_type_name : Union [InterpolatedString , str ] = "grant_type"
6768 grant_type : Union [InterpolatedString , str ] = "refresh_token"
6869 message_repository : MessageRepository = NoopMessageRepository ()
70+ profile_assertion : Optional [DeclarativeAuthenticator ] = None
71+ use_profile_assertion : Optional [Union [InterpolatedBoolean , str , bool ]] = False
6972
7073 def __post_init__ (self , parameters : Mapping [str , Any ]) -> None :
7174 super ().__init__ ()
@@ -76,11 +79,19 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
7679 else :
7780 self ._token_refresh_endpoint = None
7881 self ._client_id_name = InterpolatedString .create (self .client_id_name , parameters = parameters )
79- self ._client_id = InterpolatedString .create (self .client_id , parameters = parameters )
82+ self ._client_id = (
83+ InterpolatedString .create (self .client_id , parameters = parameters )
84+ if self .client_id
85+ else self .client_id
86+ )
8087 self ._client_secret_name = InterpolatedString .create (
8188 self .client_secret_name , parameters = parameters
8289 )
83- self ._client_secret = InterpolatedString .create (self .client_secret , parameters = parameters )
90+ self ._client_secret = (
91+ InterpolatedString .create (self .client_secret , parameters = parameters )
92+ if self .client_secret
93+ else self .client_secret
94+ )
8495 self ._refresh_token_name = InterpolatedString .create (
8596 self .refresh_token_name , parameters = parameters
8697 )
@@ -99,7 +110,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
99110 self .grant_type_name = InterpolatedString .create (
100111 self .grant_type_name , parameters = parameters
101112 )
102- self .grant_type = InterpolatedString .create (self .grant_type , parameters = parameters )
113+ self .grant_type = InterpolatedString .create (
114+ "urn:ietf:params:oauth:grant-type:jwt-bearer"
115+ if self .use_profile_assertion
116+ else self .grant_type ,
117+ parameters = parameters ,
118+ )
103119 self ._refresh_request_body = InterpolatedMapping (
104120 self .refresh_request_body or {}, parameters = parameters
105121 )
@@ -115,6 +131,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
115131 if self .token_expiry_date
116132 else pendulum .now ().subtract (days = 1 ) # type: ignore # substract does not have type hints
117133 )
134+ self .use_profile_assertion = (
135+ InterpolatedBoolean (self .use_profile_assertion , parameters = parameters )
136+ if isinstance (self .use_profile_assertion , str )
137+ else self .use_profile_assertion
138+ )
139+ self .assertion_name = "assertion"
140+
118141 if self .access_token_value is not None :
119142 self ._access_token_value = InterpolatedString .create (
120143 self .access_token_value , parameters = parameters
@@ -126,9 +149,20 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
126149 self ._access_token_value if self .access_token_value else None
127150 )
128151
152+ if not self .use_profile_assertion and any (
153+ client_creds is None for client_creds in [self .client_id , self .client_secret ]
154+ ):
155+ raise ValueError (
156+ "OAuthAuthenticator configuration error: Both 'client_id' and 'client_secret' are required for the "
157+ "basic OAuth flow."
158+ )
159+ if self .profile_assertion is None and self .use_profile_assertion :
160+ raise ValueError (
161+ "OAuthAuthenticator configuration error: 'profile_assertion' is required when using the profile assertion flow."
162+ )
129163 if self .get_grant_type () == "refresh_token" and self ._refresh_token is None :
130164 raise ValueError (
131- "OAuthAuthenticator needs a refresh_token parameter if grant_type is set to ` refresh_token` "
165+ "OAuthAuthenticator configuration error: A ' refresh_token' is required when the ' grant_type' is set to ' refresh_token'. "
132166 )
133167
134168 def get_token_refresh_endpoint (self ) -> Optional [str ]:
@@ -145,19 +179,21 @@ def get_client_id_name(self) -> str:
145179 return self ._client_id_name .eval (self .config ) # type: ignore # eval returns a string in this context
146180
147181 def get_client_id (self ) -> str :
148- client_id : str = self ._client_id .eval (self .config )
182+ client_id = self ._client_id .eval (self .config ) if self . _client_id else self . _client_id
149183 if not client_id :
150184 raise ValueError ("OAuthAuthenticator was unable to evaluate client_id parameter" )
151- return client_id
185+ return client_id # type: ignore # value will be returned as a string, or an error will be raised
152186
153187 def get_client_secret_name (self ) -> str :
154188 return self ._client_secret_name .eval (self .config ) # type: ignore # eval returns a string in this context
155189
156190 def get_client_secret (self ) -> str :
157- client_secret : str = self ._client_secret .eval (self .config )
191+ client_secret = (
192+ self ._client_secret .eval (self .config ) if self ._client_secret else self ._client_secret
193+ )
158194 if not client_secret :
159195 raise ValueError ("OAuthAuthenticator was unable to evaluate client_secret parameter" )
160- return client_secret
196+ return client_secret # type: ignore # value will be returned as a string, or an error will be raised
161197
162198 def get_refresh_token_name (self ) -> str :
163199 return self ._refresh_token_name .eval (self .config ) # type: ignore # eval returns a string in this context
@@ -192,6 +228,27 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
192228 def set_token_expiry_date (self , value : Union [str , int ]) -> None :
193229 self ._token_expiry_date = self ._parse_token_expiration_date (value )
194230
231+ def get_assertion_name (self ) -> str :
232+ return self .assertion_name
233+
234+ def get_assertion (self ) -> str :
235+ if self .profile_assertion is None :
236+ raise ValueError ("profile_assertion is not set" )
237+ return self .profile_assertion .token
238+
239+ def build_refresh_request_body (self ) -> Mapping [str , Any ]:
240+ """
241+ Returns the request body to set on the refresh request
242+
243+ Override to define additional parameters
244+ """
245+ if self .use_profile_assertion :
246+ return {
247+ self .get_grant_type_name (): self .get_grant_type (),
248+ self .get_assertion_name (): self .get_assertion (),
249+ }
250+ return super ().build_refresh_request_body ()
251+
195252 @property
196253 def access_token (self ) -> str :
197254 if self ._access_token is None :
0 commit comments