3333
3434from authlib .integrations .httpx_client import AsyncOAuth2Client
3535
36+ from confluent_kafka .schema_registry .common ._oauthbearer import (
37+ _BearerFieldProvider ,
38+ _AbstractOAuthBearerOIDCFieldProviderBuilder ,
39+ _StaticOAuthBearerFieldProviderBuilder ,
40+ _AbstractCustomOAuthBearerFieldProviderBuilder )
3641from confluent_kafka .schema_registry .error import SchemaRegistryError , OAuthTokenError
3742from confluent_kafka .schema_registry .common .schema_registry_client import (
3843 RegisteredSchema ,
3944 SchemaVersion ,
4045 ServerConfig ,
4146 is_success ,
4247 is_retriable ,
43- _BearerFieldProvider ,
4448 full_jitter ,
4549 _SchemaCache ,
46- Schema ,
47- _StaticFieldProvider ,
50+ Schema
4851)
4952
5053__all__ = [
5659 'AsyncSchemaRegistryClient' ,
5760]
5861
59- # TODO: consider adding `six` dependency or employing a compat file
60- # Python 2.7 is officially EOL so compatibility issue will be come more the norm.
61- # We need a better way to handle these issues.
62- # Six is one possibility but the compat file pattern used by requests
63- # is also quite nice.
64- #
65- # six: https://pypi.org/project/six/
66- # compat file : https://github.com/psf/requests/blob/master/requests/compat.py
67- try :
68- string_type = basestring # noqa
6962
70- def _urlencode (value : str ) -> str :
71- return urllib .quote (value , safe = '' )
72- except NameError :
73- string_type = str
63+ def _urlencode (value : str ) -> str :
64+ return urllib .parse .quote (value , safe = '' )
7465
75- def _urlencode (value : str ) -> str :
76- return urllib .parse .quote (value , safe = '' )
7766
7867log = logging .getLogger (__name__ )
7968
@@ -130,6 +119,55 @@ async def generate_access_token(self) -> None:
130119 await asyncio .sleep (full_jitter (self .retries_wait_ms , self .retries_max_wait_ms , i ) / 1000 )
131120
132121
122+ class _AsyncOAuthBearerOIDCFieldProviderBuilder (_AbstractOAuthBearerOIDCFieldProviderBuilder ):
123+
124+ def build (self , max_retries , retries_wait_ms , retries_max_wait_ms ):
125+ self ._validate ()
126+ return _AsyncOAuthClient (
127+ self .client_id , self .client_secret , self .scope ,
128+ self .token_endpoint ,
129+ self .logical_cluster ,
130+ self .identity_pool ,
131+ max_retries , retries_wait_ms ,
132+ retries_max_wait_ms )
133+
134+
135+ class _AsyncCustomOAuthBearerFieldProviderBuilder (_AbstractCustomOAuthBearerFieldProviderBuilder ):
136+
137+ def build (self , max_retries , retries_wait_ms , retries_max_wait_ms ):
138+ self ._validate ()
139+ return _AsyncCustomOAuthClient (
140+ self .custom_function ,
141+ self .custom_config
142+ )
143+
144+
145+ class _AsyncFieldProviderBuilder :
146+
147+ __builders = {
148+ "OAUTHBEARER" : _AsyncOAuthBearerOIDCFieldProviderBuilder ,
149+ "STATIC_TOKEN" : _StaticOAuthBearerFieldProviderBuilder ,
150+ "CUSTOM" : _AsyncCustomOAuthBearerFieldProviderBuilder
151+ }
152+
153+ @staticmethod
154+ def build (conf , max_retries , retries_wait_ms , retries_max_wait_ms ):
155+ bearer_auth_credentials_source = conf .pop ('bearer.auth.credentials.source' , None )
156+ if bearer_auth_credentials_source is None :
157+ return [None , None ]
158+
159+ if bearer_auth_credentials_source not in _AsyncFieldProviderBuilder .__builders :
160+ raise ValueError ('Unrecognized bearer.auth.credentials.source' )
161+ bearer_field_provider_builder = _AsyncFieldProviderBuilder .__builders [bearer_auth_credentials_source ](conf )
162+ return (
163+ bearer_auth_credentials_source ,
164+ bearer_field_provider_builder .build (
165+ max_retries , retries_wait_ms ,
166+ retries_max_wait_ms
167+ )
168+ )
169+
170+
133171class _AsyncBaseRestClient (object ):
134172
135173 def __init__ (self , conf : dict ):
@@ -139,7 +177,7 @@ def __init__(self, conf: dict):
139177 base_url = conf_copy .pop ('url' , None )
140178 if base_url is None :
141179 raise ValueError ("Missing required configuration property url" )
142- if not isinstance (base_url , string_type ):
180+ if not isinstance (base_url , str ):
143181 raise TypeError ("url must be a str, not " + str (type (base_url )))
144182 base_urls = []
145183 for url in base_url .split (',' ):
@@ -259,86 +297,10 @@ def __init__(self, conf: dict):
259297 + str (type (retries_max_wait_ms )))
260298 self .retries_max_wait_ms = retries_max_wait_ms
261299
262- self .bearer_field_provider = None
263- logical_cluster = None
264- identity_pool = None
265- self .bearer_auth_credentials_source = conf_copy .pop ('bearer.auth.credentials.source' , None )
266- if self .bearer_auth_credentials_source is not None :
267- self .auth = None
268-
269- if self .bearer_auth_credentials_source in {'OAUTHBEARER' , 'STATIC_TOKEN' }:
270- headers = ['bearer.auth.logical.cluster' , 'bearer.auth.identity.pool.id' ]
271- missing_headers = [header for header in headers if header not in conf_copy ]
272- if missing_headers :
273- raise ValueError ("Missing required bearer configuration properties: {}"
274- .format (", " .join (missing_headers )))
275-
276- logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
277- if not isinstance (logical_cluster , str ):
278- raise TypeError ("logical cluster must be a str, not " + str (type (logical_cluster )))
279-
280- identity_pool = conf_copy .pop ('bearer.auth.identity.pool.id' )
281- if not isinstance (identity_pool , str ):
282- raise TypeError ("identity pool id must be a str, not " + str (type (identity_pool )))
283-
284- if self .bearer_auth_credentials_source == 'OAUTHBEARER' :
285- properties_list = ['bearer.auth.client.id' , 'bearer.auth.client.secret' , 'bearer.auth.scope' ,
286- 'bearer.auth.issuer.endpoint.url' ]
287- missing_properties = [prop for prop in properties_list if prop not in conf_copy ]
288- if missing_properties :
289- raise ValueError ("Missing required OAuth configuration properties: {}" .
290- format (", " .join (missing_properties )))
291-
292- self .client_id = conf_copy .pop ('bearer.auth.client.id' )
293- if not isinstance (self .client_id , string_type ):
294- raise TypeError ("bearer.auth.client.id must be a str, not " + str (type (self .client_id )))
295-
296- self .client_secret = conf_copy .pop ('bearer.auth.client.secret' )
297- if not isinstance (self .client_secret , string_type ):
298- raise TypeError ("bearer.auth.client.secret must be a str, not " + str (type (self .client_secret )))
299-
300- self .scope = conf_copy .pop ('bearer.auth.scope' )
301- if not isinstance (self .scope , string_type ):
302- raise TypeError ("bearer.auth.scope must be a str, not " + str (type (self .scope )))
303-
304- self .token_endpoint = conf_copy .pop ('bearer.auth.issuer.endpoint.url' )
305- if not isinstance (self .token_endpoint , string_type ):
306- raise TypeError ("bearer.auth.issuer.endpoint.url must be a str, not "
307- + str (type (self .token_endpoint )))
308-
309- self .bearer_field_provider = _AsyncOAuthClient (
310- self .client_id , self .client_secret , self .scope ,
311- self .token_endpoint , logical_cluster , identity_pool ,
312- self .max_retries , self .retries_wait_ms ,
313- self .retries_max_wait_ms )
314- elif self .bearer_auth_credentials_source == 'STATIC_TOKEN' :
315- if 'bearer.auth.token' not in conf_copy :
316- raise ValueError ("Missing bearer.auth.token" )
317- static_token = conf_copy .pop ('bearer.auth.token' )
318- self .bearer_field_provider = _StaticFieldProvider (static_token , logical_cluster , identity_pool )
319- if not isinstance (static_token , string_type ):
320- raise TypeError ("bearer.auth.token must be a str, not " + str (type (static_token )))
321- elif self .bearer_auth_credentials_source == 'CUSTOM' :
322- custom_bearer_properties = ['bearer.auth.custom.provider.function' ,
323- 'bearer.auth.custom.provider.config' ]
324- missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy ]
325- if missing_custom_properties :
326- raise ValueError ("Missing required custom OAuth configuration properties: {}" .
327- format (", " .join (missing_custom_properties )))
328-
329- custom_function = conf_copy .pop ('bearer.auth.custom.provider.function' )
330- if not callable (custom_function ):
331- raise TypeError ("bearer.auth.custom.provider.function must be a callable, not "
332- + str (type (custom_function )))
333-
334- custom_config = conf_copy .pop ('bearer.auth.custom.provider.config' )
335- if not isinstance (custom_config , dict ):
336- raise TypeError ("bearer.auth.custom.provider.config must be a dict, not "
337- + str (type (custom_config )))
338-
339- self .bearer_field_provider = _AsyncCustomOAuthClient (custom_function , custom_config )
340- else :
341- raise ValueError ('Unrecognized bearer.auth.credentials.source' )
300+ [self .bearer_auth_credentials_source , self .bearer_field_provider ] = \
301+ _AsyncFieldProviderBuilder .build (
302+ conf_copy , self .max_retries , self .retries_wait_ms ,
303+ self .retries_max_wait_ms )
342304
343305 # Any leftover keys are unknown to _RestClient
344306 if len (conf_copy ) > 0 :
0 commit comments