Skip to content

Commit 6aea1a5

Browse files
committed
Refactor OAuth token providers validation
to be common for the sync and async clients
1 parent 9ad4e3a commit 6aea1a5

File tree

8 files changed

+308
-230
lines changed

8 files changed

+308
-230
lines changed

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 62 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,21 @@
3333

3434
from authlib.integrations.httpx_client import AsyncOAuth2Client
3535

36+
from confluent_kafka.schema_registry.common._oauthbearer import (
37+
_BearerFieldProvider,
38+
_AbstractOAuthBearerOIDCFieldProviderBuilder,
39+
_StaticOAuthBearerFieldProviderBuilder,
40+
_AbstractCustomOAuthBearerFieldProviderBuilder)
3641
from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError
3742
from confluent_kafka.schema_registry.common.schema_registry_client import (
3843
RegisteredSchema,
3944
SchemaVersion,
4045
ServerConfig,
4146
is_success,
4247
is_retriable,
43-
_BearerFieldProvider,
4448
full_jitter,
4549
_SchemaCache,
46-
Schema,
47-
_StaticFieldProvider,
50+
Schema
4851
)
4952

5053
__all__ = [
@@ -56,24 +59,10 @@
5659
'AsyncSchemaRegistryClient',
5760
]
5861

59-
# TODO: consider adding `six` dependency or employing a compat file
60-
# Python 2.7 is officially EOL so compatibility issue will be come more the norm.
61-
# We need a better way to handle these issues.
62-
# Six is one possibility but the compat file pattern used by requests
63-
# is also quite nice.
64-
#
65-
# six: https://pypi.org/project/six/
66-
# compat file : https://github.com/psf/requests/blob/master/requests/compat.py
67-
try:
68-
string_type = basestring # noqa
6962

70-
def _urlencode(value: str) -> str:
71-
return urllib.quote(value, safe='')
72-
except NameError:
73-
string_type = str
63+
def _urlencode(value: str) -> str:
64+
return urllib.parse.quote(value, safe='')
7465

75-
def _urlencode(value: str) -> str:
76-
return urllib.parse.quote(value, safe='')
7766

7867
log = logging.getLogger(__name__)
7968

@@ -130,6 +119,55 @@ async def generate_access_token(self) -> None:
130119
await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000)
131120

132121

122+
class _AsyncOAuthBearerOIDCFieldProviderBuilder(_AbstractOAuthBearerOIDCFieldProviderBuilder):
123+
124+
def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
125+
self._validate()
126+
return _AsyncOAuthClient(
127+
self.client_id, self.client_secret, self.scope,
128+
self.token_endpoint,
129+
self.logical_cluster,
130+
self.identity_pool,
131+
max_retries, retries_wait_ms,
132+
retries_max_wait_ms)
133+
134+
135+
class _AsyncCustomOAuthBearerFieldProviderBuilder(_AbstractCustomOAuthBearerFieldProviderBuilder):
136+
137+
def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
138+
self._validate()
139+
return _AsyncCustomOAuthClient(
140+
self.custom_function,
141+
self.custom_config
142+
)
143+
144+
145+
class _AsyncFieldProviderBuilder:
146+
147+
__builders = {
148+
"OAUTHBEARER": _AsyncOAuthBearerOIDCFieldProviderBuilder,
149+
"STATIC_TOKEN": _StaticOAuthBearerFieldProviderBuilder,
150+
"CUSTOM": _AsyncCustomOAuthBearerFieldProviderBuilder
151+
}
152+
153+
@staticmethod
154+
def build(conf, max_retries, retries_wait_ms, retries_max_wait_ms):
155+
bearer_auth_credentials_source = conf.pop('bearer.auth.credentials.source', None)
156+
if bearer_auth_credentials_source is None:
157+
return [None, None]
158+
159+
if bearer_auth_credentials_source not in _AsyncFieldProviderBuilder.__builders:
160+
raise ValueError('Unrecognized bearer.auth.credentials.source')
161+
bearer_field_provider_builder = _AsyncFieldProviderBuilder.__builders[bearer_auth_credentials_source](conf)
162+
return (
163+
bearer_auth_credentials_source,
164+
bearer_field_provider_builder.build(
165+
max_retries, retries_wait_ms,
166+
retries_max_wait_ms
167+
)
168+
)
169+
170+
133171
class _AsyncBaseRestClient(object):
134172

135173
def __init__(self, conf: dict):
@@ -139,7 +177,7 @@ def __init__(self, conf: dict):
139177
base_url = conf_copy.pop('url', None)
140178
if base_url is None:
141179
raise ValueError("Missing required configuration property url")
142-
if not isinstance(base_url, string_type):
180+
if not isinstance(base_url, str):
143181
raise TypeError("url must be a str, not " + str(type(base_url)))
144182
base_urls = []
145183
for url in base_url.split(','):
@@ -259,86 +297,10 @@ def __init__(self, conf: dict):
259297
+ str(type(retries_max_wait_ms)))
260298
self.retries_max_wait_ms = retries_max_wait_ms
261299

262-
self.bearer_field_provider = None
263-
logical_cluster = None
264-
identity_pool = None
265-
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
266-
if self.bearer_auth_credentials_source is not None:
267-
self.auth = None
268-
269-
if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}:
270-
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
271-
missing_headers = [header for header in headers if header not in conf_copy]
272-
if missing_headers:
273-
raise ValueError("Missing required bearer configuration properties: {}"
274-
.format(", ".join(missing_headers)))
275-
276-
logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
277-
if not isinstance(logical_cluster, str):
278-
raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster)))
279-
280-
identity_pool = conf_copy.pop('bearer.auth.identity.pool.id')
281-
if not isinstance(identity_pool, str):
282-
raise TypeError("identity pool id must be a str, not " + str(type(identity_pool)))
283-
284-
if self.bearer_auth_credentials_source == 'OAUTHBEARER':
285-
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
286-
'bearer.auth.issuer.endpoint.url']
287-
missing_properties = [prop for prop in properties_list if prop not in conf_copy]
288-
if missing_properties:
289-
raise ValueError("Missing required OAuth configuration properties: {}".
290-
format(", ".join(missing_properties)))
291-
292-
self.client_id = conf_copy.pop('bearer.auth.client.id')
293-
if not isinstance(self.client_id, string_type):
294-
raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id)))
295-
296-
self.client_secret = conf_copy.pop('bearer.auth.client.secret')
297-
if not isinstance(self.client_secret, string_type):
298-
raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret)))
299-
300-
self.scope = conf_copy.pop('bearer.auth.scope')
301-
if not isinstance(self.scope, string_type):
302-
raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope)))
303-
304-
self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url')
305-
if not isinstance(self.token_endpoint, string_type):
306-
raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
307-
+ str(type(self.token_endpoint)))
308-
309-
self.bearer_field_provider = _AsyncOAuthClient(
310-
self.client_id, self.client_secret, self.scope,
311-
self.token_endpoint, logical_cluster, identity_pool,
312-
self.max_retries, self.retries_wait_ms,
313-
self.retries_max_wait_ms)
314-
elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
315-
if 'bearer.auth.token' not in conf_copy:
316-
raise ValueError("Missing bearer.auth.token")
317-
static_token = conf_copy.pop('bearer.auth.token')
318-
self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool)
319-
if not isinstance(static_token, string_type):
320-
raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token)))
321-
elif self.bearer_auth_credentials_source == 'CUSTOM':
322-
custom_bearer_properties = ['bearer.auth.custom.provider.function',
323-
'bearer.auth.custom.provider.config']
324-
missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy]
325-
if missing_custom_properties:
326-
raise ValueError("Missing required custom OAuth configuration properties: {}".
327-
format(", ".join(missing_custom_properties)))
328-
329-
custom_function = conf_copy.pop('bearer.auth.custom.provider.function')
330-
if not callable(custom_function):
331-
raise TypeError("bearer.auth.custom.provider.function must be a callable, not "
332-
+ str(type(custom_function)))
333-
334-
custom_config = conf_copy.pop('bearer.auth.custom.provider.config')
335-
if not isinstance(custom_config, dict):
336-
raise TypeError("bearer.auth.custom.provider.config must be a dict, not "
337-
+ str(type(custom_config)))
338-
339-
self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config)
340-
else:
341-
raise ValueError('Unrecognized bearer.auth.credentials.source')
300+
[self.bearer_auth_credentials_source, self.bearer_field_provider] = \
301+
_AsyncFieldProviderBuilder.build(
302+
conf_copy, self.max_retries, self.retries_wait_ms,
303+
self.retries_max_wait_ms)
342304

343305
# Any leftover keys are unknown to _RestClient
344306
if len(conf_copy) > 0:

0 commit comments

Comments
 (0)