From 6aea1a532dfc43be7878ce1d6c1a6a902600fa6e Mon Sep 17 00:00:00 2001 From: Emanuele Sabellico Date: Thu, 18 Sep 2025 14:00:26 +0200 Subject: [PATCH 1/4] Refactor OAuth token providers validation to be common for the sync and async clients --- .../_async/schema_registry_client.py | 162 +++++++--------- .../_sync/schema_registry_client.py | 162 +++++++--------- .../schema_registry/common/_oauthbearer.py | 174 ++++++++++++++++++ .../common/schema_registry_client.py | 20 -- .../_async/test_bearer_field_provider.py | 2 +- tests/schema_registry/_async/test_config.py | 8 +- .../_sync/test_bearer_field_provider.py | 2 +- tests/schema_registry/_sync/test_config.py | 8 +- 8 files changed, 308 insertions(+), 230 deletions(-) create mode 100644 src/confluent_kafka/schema_registry/common/_oauthbearer.py diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 368748fe6..effd20e4b 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -33,6 +33,11 @@ from authlib.integrations.httpx_client import AsyncOAuth2Client +from confluent_kafka.schema_registry.common._oauthbearer import ( + _BearerFieldProvider, + _AbstractOAuthBearerOIDCFieldProviderBuilder, + _StaticOAuthBearerFieldProviderBuilder, + _AbstractCustomOAuthBearerFieldProviderBuilder) from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError from confluent_kafka.schema_registry.common.schema_registry_client import ( RegisteredSchema, @@ -40,11 +45,9 @@ ServerConfig, is_success, is_retriable, - _BearerFieldProvider, full_jitter, _SchemaCache, - Schema, - _StaticFieldProvider, + Schema ) __all__ = [ @@ -56,24 +59,10 @@ 'AsyncSchemaRegistryClient', ] -# TODO: consider adding `six` dependency or employing a compat file -# Python 2.7 is officially EOL so compatibility issue will be come more the 
norm. -# We need a better way to handle these issues. -# Six is one possibility but the compat file pattern used by requests -# is also quite nice. -# -# six: https://pypi.org/project/six/ -# compat file : https://github.com/psf/requests/blob/master/requests/compat.py -try: - string_type = basestring # noqa - def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') -except NameError: - string_type = str +def _urlencode(value: str) -> str: + return urllib.parse.quote(value, safe='') - def _urlencode(value: str) -> str: - return urllib.parse.quote(value, safe='') log = logging.getLogger(__name__) @@ -130,6 +119,55 @@ async def generate_access_token(self) -> None: await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) +class _AsyncOAuthBearerOIDCFieldProviderBuilder(_AbstractOAuthBearerOIDCFieldProviderBuilder): + + def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + self._validate() + return _AsyncOAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, + self.logical_cluster, + self.identity_pool, + max_retries, retries_wait_ms, + retries_max_wait_ms) + + +class _AsyncCustomOAuthBearerFieldProviderBuilder(_AbstractCustomOAuthBearerFieldProviderBuilder): + + def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + self._validate() + return _AsyncCustomOAuthClient( + self.custom_function, + self.custom_config + ) + + +class _AsyncFieldProviderBuilder: + + __builders = { + "OAUTHBEARER": _AsyncOAuthBearerOIDCFieldProviderBuilder, + "STATIC_TOKEN": _StaticOAuthBearerFieldProviderBuilder, + "CUSTOM": _AsyncCustomOAuthBearerFieldProviderBuilder + } + + @staticmethod + def build(conf, max_retries, retries_wait_ms, retries_max_wait_ms): + bearer_auth_credentials_source = conf.pop('bearer.auth.credentials.source', None) + if bearer_auth_credentials_source is None: + return [None, None] + + if bearer_auth_credentials_source not in 
_AsyncFieldProviderBuilder.__builders: + raise ValueError('Unrecognized bearer.auth.credentials.source') + bearer_field_provider_builder = _AsyncFieldProviderBuilder.__builders[bearer_auth_credentials_source](conf) + return ( + bearer_auth_credentials_source, + bearer_field_provider_builder.build( + max_retries, retries_wait_ms, + retries_max_wait_ms + ) + ) + + class _AsyncBaseRestClient(object): def __init__(self, conf: dict): @@ -139,7 +177,7 @@ def __init__(self, conf: dict): base_url = conf_copy.pop('url', None) if base_url is None: raise ValueError("Missing required configuration property url") - if not isinstance(base_url, string_type): + if not isinstance(base_url, str): raise TypeError("url must be a str, not " + str(type(base_url))) base_urls = [] for url in base_url.split(','): @@ -259,86 +297,10 @@ def __init__(self, conf: dict): + str(type(retries_max_wait_ms))) self.retries_max_wait_ms = retries_max_wait_ms - self.bearer_field_provider = None - logical_cluster = None - identity_pool = None - self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) - if self.bearer_auth_credentials_source is not None: - self.auth = None - - if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: - raise ValueError("Missing required bearer configuration properties: {}" - .format(", ".join(missing_headers))) - - logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') - if not isinstance(logical_cluster, str): - raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - - if self.bearer_auth_credentials_source == 
'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". - format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _AsyncOAuthClient( - self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) - elif self.bearer_auth_credentials_source == 'CUSTOM': - custom_bearer_properties = 
['bearer.auth.custom.provider.function', - 'bearer.auth.custom.provider.config'] - missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] - if missing_custom_properties: - raise ValueError("Missing required custom OAuth configuration properties: {}". - format(", ".join(missing_custom_properties))) - - custom_function = conf_copy.pop('bearer.auth.custom.provider.function') - if not callable(custom_function): - raise TypeError("bearer.auth.custom.provider.function must be a callable, not " - + str(type(custom_function))) - - custom_config = conf_copy.pop('bearer.auth.custom.provider.config') - if not isinstance(custom_config, dict): - raise TypeError("bearer.auth.custom.provider.config must be a dict, not " - + str(type(custom_config))) - - self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) - else: - raise ValueError('Unrecognized bearer.auth.credentials.source') + [self.bearer_auth_credentials_source, self.bearer_field_provider] = \ + _AsyncFieldProviderBuilder.build( + conf_copy, self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) # Any leftover keys are unknown to _RestClient if len(conf_copy) > 0: diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 1f579f6a8..2b3b5d10d 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -33,6 +33,11 @@ from authlib.integrations.httpx_client import OAuth2Client +from confluent_kafka.schema_registry.common._oauthbearer import ( + _BearerFieldProvider, + _AbstractOAuthBearerOIDCFieldProviderBuilder, + _StaticOAuthBearerFieldProviderBuilder, + _AbstractCustomOAuthBearerFieldProviderBuilder) from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError from 
confluent_kafka.schema_registry.common.schema_registry_client import ( RegisteredSchema, @@ -40,11 +45,9 @@ ServerConfig, is_success, is_retriable, - _BearerFieldProvider, full_jitter, _SchemaCache, - Schema, - _StaticFieldProvider, + Schema ) __all__ = [ @@ -56,24 +59,10 @@ 'SchemaRegistryClient', ] -# TODO: consider adding `six` dependency or employing a compat file -# Python 2.7 is officially EOL so compatibility issue will be come more the norm. -# We need a better way to handle these issues. -# Six is one possibility but the compat file pattern used by requests -# is also quite nice. -# -# six: https://pypi.org/project/six/ -# compat file : https://github.com/psf/requests/blob/master/requests/compat.py -try: - string_type = basestring # noqa - def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') -except NameError: - string_type = str +def _urlencode(value: str) -> str: + return urllib.parse.quote(value, safe='') - def _urlencode(value: str) -> str: - return urllib.parse.quote(value, safe='') log = logging.getLogger(__name__) @@ -130,6 +119,55 @@ def generate_access_token(self) -> None: time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) +class _OAuthBearerOIDCFieldProviderBuilder(_AbstractOAuthBearerOIDCFieldProviderBuilder): + + def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + self._validate() + return _OAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, + self.logical_cluster, + self.identity_pool, + max_retries, retries_wait_ms, + retries_max_wait_ms) + + +class _CustomOAuthBearerFieldProviderBuilder(_AbstractCustomOAuthBearerFieldProviderBuilder): + + def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + self._validate() + return _CustomOAuthClient( + self.custom_function, + self.custom_config + ) + + +class _FieldProviderBuilder: + + __builders = { + "OAUTHBEARER": _OAuthBearerOIDCFieldProviderBuilder, + "STATIC_TOKEN": 
_StaticOAuthBearerFieldProviderBuilder, + "CUSTOM": _CustomOAuthBearerFieldProviderBuilder + } + + @staticmethod + def build(conf, max_retries, retries_wait_ms, retries_max_wait_ms): + bearer_auth_credentials_source = conf.pop('bearer.auth.credentials.source', None) + if bearer_auth_credentials_source is None: + return [None, None] + + if bearer_auth_credentials_source not in _FieldProviderBuilder.__builders: + raise ValueError('Unrecognized bearer.auth.credentials.source') + bearer_field_provider_builder = _FieldProviderBuilder.__builders[bearer_auth_credentials_source](conf) + return ( + bearer_auth_credentials_source, + bearer_field_provider_builder.build( + max_retries, retries_wait_ms, + retries_max_wait_ms + ) + ) + + class _BaseRestClient(object): def __init__(self, conf: dict): @@ -139,7 +177,7 @@ def __init__(self, conf: dict): base_url = conf_copy.pop('url', None) if base_url is None: raise ValueError("Missing required configuration property url") - if not isinstance(base_url, string_type): + if not isinstance(base_url, str): raise TypeError("url must be a str, not " + str(type(base_url))) base_urls = [] for url in base_url.split(','): @@ -259,86 +297,10 @@ def __init__(self, conf: dict): + str(type(retries_max_wait_ms))) self.retries_max_wait_ms = retries_max_wait_ms - self.bearer_field_provider = None - logical_cluster = None - identity_pool = None - self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) - if self.bearer_auth_credentials_source is not None: - self.auth = None - - if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: - raise ValueError("Missing required bearer configuration properties: {}" - .format(", ".join(missing_headers))) - - logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') - if not 
isinstance(logical_cluster, str): - raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". - format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _OAuthClient( - self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - 
self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) - elif self.bearer_auth_credentials_source == 'CUSTOM': - custom_bearer_properties = ['bearer.auth.custom.provider.function', - 'bearer.auth.custom.provider.config'] - missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] - if missing_custom_properties: - raise ValueError("Missing required custom OAuth configuration properties: {}". - format(", ".join(missing_custom_properties))) - - custom_function = conf_copy.pop('bearer.auth.custom.provider.function') - if not callable(custom_function): - raise TypeError("bearer.auth.custom.provider.function must be a callable, not " - + str(type(custom_function))) - - custom_config = conf_copy.pop('bearer.auth.custom.provider.config') - if not isinstance(custom_config, dict): - raise TypeError("bearer.auth.custom.provider.config must be a dict, not " - + str(type(custom_config))) - - self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) - else: - raise ValueError('Unrecognized bearer.auth.credentials.source') + [self.bearer_auth_credentials_source, self.bearer_field_provider] = \ + _FieldProviderBuilder.build( + conf_copy, self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) # Any leftover keys are unknown to _RestClient if len(conf_copy) > 0: diff --git a/src/confluent_kafka/schema_registry/common/_oauthbearer.py b/src/confluent_kafka/schema_registry/common/_oauthbearer.py new file mode 100644 index 000000000..5261629c6 --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/_oauthbearer.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2025 Confluent Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc + +__all__ = [ + '_AbstractOAuthBearerFieldProviderBuilder', + '_AbstractOAuthBearerOIDCFieldProviderBuilder', + '_StaticOAuthBearerFieldProviderBuilder', + '_AbstractCustomOAuthBearerFieldProviderBuilder' +] + + +class _AbstractOAuthBearerFieldProviderBuilder(metaclass=abc.ABCMeta): + """Abstract base class for OAuthBearer client builders""" + required_properties = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] + + def __init__(self, conf): + self.conf = conf + self.logical_cluster = None + self.identity_pool = None + + def _validate(self): + missing_properties = [prop for prop in + _AbstractOAuthBearerFieldProviderBuilder.required_properties + if prop not in self.conf] + if missing_properties: + raise ValueError("Missing required bearer configuration properties: {}" + .format(", ".join(missing_properties))) + + self.logical_cluster = self.conf.pop('bearer.auth.logical.cluster') + if not isinstance(self.logical_cluster, str): + raise TypeError("logical cluster must be a str, not " + + str(type(self.logical_cluster))) + + self.identity_pool = self.conf.pop('bearer.auth.identity.pool.id') + if not isinstance(self.identity_pool, str): + raise TypeError("identity pool id must be a str, not " + + str(type(self.identity_pool))) + + @abc.abstractmethod + def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + pass + + +class 
_AbstractOAuthBearerOIDCFieldProviderBuilder(_AbstractOAuthBearerFieldProviderBuilder): + required_properties = ['bearer.auth.client.id', 'bearer.auth.client.secret', + 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + + def __init__(self, conf): + super().__init__(conf) + self.client_id = None + self.client_secret = None + self.scope = None + self.token_endpoint = None + + def _validate(self): + super()._validate() + + missing_properties = [prop for prop in + _AbstractOAuthBearerOIDCFieldProviderBuilder.required_properties + if prop not in self.conf] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". + format(", ".join(missing_properties))) + + self.client_id = self.conf.pop('bearer.auth.client.id') + if not isinstance(self.client_id, str): + raise TypeError("bearer.auth.client.id must be a str, not " + + str(type(self.client_id))) + + self.client_secret = self.conf.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, str): + raise TypeError("bearer.auth.client.secret must be a str, not " + + str(type(self.client_secret))) + + self.scope = self.conf.pop('bearer.auth.scope') + if not isinstance(self.scope, str): + raise TypeError("bearer.auth.scope must be a str, not " + + str(type(self.scope))) + + self.token_endpoint = self.conf.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, str): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + +class _BearerFieldProvider(metaclass=abc.ABCMeta): + @abc.abstractmethod + def get_bearer_fields(self) -> dict: + raise NotImplementedError + + +class _StaticFieldProvider(_BearerFieldProvider): + def __init__(self, token: str, logical_cluster: str, identity_pool: str): + self.token = token + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + + def get_bearer_fields(self) -> dict: + return {'bearer.auth.token': self.token, 
'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool} + + +class _StaticOAuthBearerFieldProviderBuilder(_AbstractOAuthBearerFieldProviderBuilder): + required_properties = ['bearer.auth.client.id', 'bearer.auth.client.secret', + 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + + def __init__(self, conf): + super().__init__(conf) + self.static_token = None + + def _validate(self): + super()._validate() + + if 'bearer.auth.token' not in self.conf: + raise ValueError("Missing bearer.auth.token") + self.static_token = self.conf.pop('bearer.auth.token') + if not isinstance(self.static_token, str): + raise TypeError("bearer.auth.token must be a str, not " + + str(type(self.static_token))) + + def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + self._validate() + return _StaticFieldProvider( + self.static_token, + self.logical_cluster, + self.identity_pool + ) + + +class _AbstractCustomOAuthBearerFieldProviderBuilder: + required_properties = ['bearer.auth.custom.provider.function', + 'bearer.auth.custom.provider.config'] + + def __init__(self, conf): + self.conf = conf + self.custom_function = None + self.custom_config = None + + def _validate(self): + missing_properties = [prop for prop in + _AbstractCustomOAuthBearerFieldProviderBuilder.required_properties + if prop not in self.conf] + if missing_properties: + raise ValueError("Missing required custom OAuth configuration properties: {}". 
+ format(", ".join(missing_properties))) + + self.custom_function = self.conf.pop('bearer.auth.custom.provider.function') + if not callable(self.custom_function): + raise TypeError("bearer.auth.custom.provider.function must be a callable, not " + + str(type(self.custom_function))) + + self.custom_config = self.conf.pop('bearer.auth.custom.provider.config') + if not isinstance(self.custom_config, dict): + raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + + str(type(self.custom_config))) diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index e52c0254b..7ba1ead2a 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import abc import random from attrs import define as _attrs_define @@ -27,11 +26,9 @@ __all__ = [ 'VALID_AUTH_PROVIDERS', - '_BearerFieldProvider', 'is_success', 'is_retriable', 'full_jitter', - '_StaticFieldProvider', '_SchemaCache', 'RuleKind', 'RuleMode', @@ -51,12 +48,6 @@ VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] -class _BearerFieldProvider(metaclass=abc.ABCMeta): - @abc.abstractmethod - def get_bearer_fields(self) -> dict: - raise NotImplementedError - - def is_success(status_code: int) -> bool: return 200 <= status_code <= 299 @@ -70,17 +61,6 @@ def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) - return random.random() * min(no_jitter_delay, max_delay_ms) -class _StaticFieldProvider(_BearerFieldProvider): - def __init__(self, token: str, logical_cluster: str, identity_pool: str): - self.token = token - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - - def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - class _SchemaCache(object): """ Thread-safe cache for use with the Schema Registry Client. 
diff --git a/tests/schema_registry/_async/test_bearer_field_provider.py b/tests/schema_registry/_async/test_bearer_field_provider.py index 4d7c71138..cb405ef26 100644 --- a/tests/schema_registry/_async/test_bearer_field_provider.py +++ b/tests/schema_registry/_async/test_bearer_field_provider.py @@ -21,7 +21,7 @@ from confluent_kafka.schema_registry._async.schema_registry_client import _AsyncOAuthClient, AsyncSchemaRegistryClient from confluent_kafka.schema_registry._async.schema_registry_client import _AsyncCustomOAuthClient -from confluent_kafka.schema_registry.common.schema_registry_client import _StaticFieldProvider +from confluent_kafka.schema_registry.common._oauthbearer import _StaticFieldProvider from confluent_kafka.schema_registry.error import OAuthTokenError """ diff --git a/tests/schema_registry/_async/test_config.py b/tests/schema_registry/_async/test_config.py index 54695948e..c6ee13fdb 100644 --- a/tests/schema_registry/_async/test_config.py +++ b/tests/schema_registry/_async/test_config.py @@ -222,10 +222,10 @@ def test_oauth_bearer_config_valid(): client = AsyncSchemaRegistryClient(conf) - assert client._rest_client.client_id == TEST_USERNAME - assert client._rest_client.client_secret == TEST_USER_PASSWORD - assert client._rest_client.scope == TEST_SCOPE - assert client._rest_client.token_endpoint == TEST_ENDPOINT + assert client._rest_client.bearer_field_provider.client.client_id == TEST_USERNAME + assert client._rest_client.bearer_field_provider.client.client_secret == TEST_USER_PASSWORD + assert client._rest_client.bearer_field_provider.client.scope == TEST_SCOPE + assert client._rest_client.bearer_field_provider.token_endpoint == TEST_ENDPOINT def test_static_bearer_config(): diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index ec086f20a..2e0f3e400 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ 
b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -21,7 +21,7 @@ from confluent_kafka.schema_registry._sync.schema_registry_client import _OAuthClient, SchemaRegistryClient from confluent_kafka.schema_registry._sync.schema_registry_client import _CustomOAuthClient -from confluent_kafka.schema_registry.common.schema_registry_client import _StaticFieldProvider +from confluent_kafka.schema_registry.common._oauthbearer import _StaticFieldProvider from confluent_kafka.schema_registry.error import OAuthTokenError """ diff --git a/tests/schema_registry/_sync/test_config.py b/tests/schema_registry/_sync/test_config.py index a9bd3f1d4..c075e869e 100644 --- a/tests/schema_registry/_sync/test_config.py +++ b/tests/schema_registry/_sync/test_config.py @@ -222,10 +222,10 @@ def test_oauth_bearer_config_valid(): client = SchemaRegistryClient(conf) - assert client._rest_client.client_id == TEST_USERNAME - assert client._rest_client.client_secret == TEST_USER_PASSWORD - assert client._rest_client.scope == TEST_SCOPE - assert client._rest_client.token_endpoint == TEST_ENDPOINT + assert client._rest_client.bearer_field_provider.client.client_id == TEST_USERNAME + assert client._rest_client.bearer_field_provider.client.client_secret == TEST_USER_PASSWORD + assert client._rest_client.bearer_field_provider.client.scope == TEST_SCOPE + assert client._rest_client.bearer_field_provider.token_endpoint == TEST_ENDPOINT def test_static_bearer_config(): From 6df4ccef9e19465c956006a54532529511c1e855 Mon Sep 17 00:00:00 2001 From: Emanuele Sabellico Date: Fri, 19 Sep 2025 14:14:48 +0200 Subject: [PATCH 2/4] OAuth Azure IMDB implementation --- CHANGELOG.md | 11 ++ examples/README.md | 2 + .../oauth_oidc_ccloud_azure_imds_producer.py | 116 ++++++++++++++++++ examples/oauth_schema_registry.py | 12 +- .../_async/schema_registry_client.py | 99 +++++++++++---- .../_sync/schema_registry_client.py | 99 +++++++++++---- .../schema_registry/common/_oauthbearer.py | 81 ++++++++---- 
.../_async/test_bearer_field_provider.py | 42 ++++++- tests/schema_registry/_async/test_config.py | 82 +++++++++++++ .../_sync/test_bearer_field_provider.py | 42 ++++++- tests/schema_registry/_sync/test_config.py | 82 +++++++++++++ 11 files changed, 591 insertions(+), 77 deletions(-) create mode 100644 examples/oauth_oidc_ccloud_azure_imds_producer.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b80c875b..2d08f59fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Confluent's Python client for Apache Kafka +## v2.12.0 + +v2.12.0 is a feature release with the following enhancements: + +- OAuth/OIDC metadata based authentication with Azure IMDS (#). + +confluent-kafka-python v2.12.0 is based on librdkafka v2.12.0, see the +[librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.12.0) +for a complete list of changes, enhancements, fixes and upgrade considerations. + + ## v2.11.1 v2.11.1 is a maintenance release with the following fixes: diff --git a/examples/README.md b/examples/README.md index 1df0bc0bd..799ead6bf 100644 --- a/examples/README.md +++ b/examples/README.md @@ -18,6 +18,8 @@ The scripts in this directory provide various examples of using Confluent's Pyth Additional examples for [Confluent Cloud](https://www.confluent.io/confluent-cloud/): * [confluent_cloud.py](confluent_cloud.py): Produce messages to Confluent Cloud and then read them back again. +* [oauth_oidc_ccloud_producer.py](oauth_oidc_ccloud_producer.py): Demonstrates OAuth/OIDC Authentication with Confluent Cloud (client credentials). +* [oauth_oidc_ccloud_azure_imds_producer.py](oauth_oidc_ccloud_azure_imds_producer.py): Demonstrates OAuth/OIDC Authentication with Confluent Cloud (Azure IMDS metadata based authentication). 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2025 Confluent Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This example uses Azure IMDS for credential-less authentication
# through to Schema Registry on Confluent Cloud

import logging
import argparse
from confluent_kafka import Producer
from confluent_kafka.serialization import StringSerializer


def producer_config(args):
    """Build the producer configuration for OAuth/OIDC via Azure IMDS.

    Args:
        args: Parsed command-line arguments (bootstrap_servers, query,
            and optionally logical_cluster / identity_pool_id).

    Returns:
        dict: librdkafka configuration parameters.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    params = {
        'bootstrap.servers': args.bootstrap_servers,
        'security.protocol': 'SASL_SSL',
        'sasl.mechanisms': 'OAUTHBEARER',
        'sasl.oauthbearer.method': 'oidc',
        'sasl.oauthbearer.metadata.authentication.type': 'azure_imds',
        'sasl.oauthbearer.config': f'query={args.query}'
    }
    # These two parameters are only applicable when producing to
    # confluent cloud where some sasl extensions are required.
    if args.logical_cluster and args.identity_pool_id:
        params['sasl.oauthbearer.extensions'] = 'logicalCluster=' + args.logical_cluster + \
            ',identityPoolId=' + args.identity_pool_id

    return params


def delivery_report(err, msg):
    """
    Reports the failure or success of a message delivery.

    Args:
        err (KafkaError): The error that occurred, or None on success.

        msg (Message): The message that was produced or failed.

    Note:
        In the delivery report callback the Message.key() and Message.value()
        will be the binary format as encoded by any configured Serializers and
        not the same object that was passed to produce().
        If you wish to pass the original object(s) for key and value to delivery
        report callback we recommend a bound callback or lambda where you pass
        the objects along.

    """
    if err is not None:
        print('Delivery failed for User record {}: {}'.format(msg.key(), err))
        return
    print('User record {} successfully produced to {} [{}] at offset {}'.format(
        msg.key(), msg.topic(), msg.partition(), msg.offset()))


def main(args):
    """Read "key|value" lines from stdin and produce them until interrupted."""
    topic = args.topic
    delimiter = args.delimiter
    producer_conf = producer_config(args)
    producer = Producer(producer_conf)
    serializer = StringSerializer('utf_8')

    print('Producing records to topic {}. ^C to exit.'.format(topic))
    while True:
        # Serve on_delivery callbacks from previous calls to produce()
        producer.poll(0.0)
        try:
            msg_data = input(">")
            msg = msg_data.split(delimiter)
            if len(msg) == 2:
                producer.produce(topic=topic,
                                 key=serializer(msg[0]),
                                 value=serializer(msg[1]),
                                 on_delivery=delivery_report)
            else:
                producer.produce(topic=topic,
                                 value=serializer(msg[0]),
                                 on_delivery=delivery_report)
        except KeyboardInterrupt:
            break

    print('\nFlushing {} records...'.format(len(producer)))
    producer.flush()


if __name__ == '__main__':
    # Description fixed: this example demonstrates Azure IMDS metadata-based
    # authentication, not the client-credentials grant it was copied from.
    parser = argparse.ArgumentParser(
        description="OAUTH example with Azure IMDS metadata-based authentication")
    parser.add_argument('-b', dest="bootstrap_servers", required=True,
                        help="Bootstrap broker(s) (host[:port])")
    parser.add_argument('-t', dest="topic", default="example_producer_oauth",
                        help="Topic name")
    parser.add_argument('-d', dest="delimiter", default="|",
                        help="Key-Value delimiter. Defaults to '|'")
    parser.add_argument('--query', dest="query", required=True,
                        help="Query parameters for Azure IMDS token endpoint")
    parser.add_argument('--logical-cluster', dest="logical_cluster", required=False, help="Logical Cluster.")
    parser.add_argument('--identity-pool-id', dest="identity_pool_id", required=False, help="Identity Pool ID.")

    main(parser.parse_args())
# Examples of setting up Schema Registry with OAuth with static token, -# Client Credentials, and custom functions +# Client Credentials, Azure IMDS, and custom functions # CUSTOM OAuth configuration takes in a custom function, config for that @@ -49,6 +49,16 @@ def main(): client_credentials_oauth_sr_client = SchemaRegistryClient(client_credentials_oauth_config) print(client_credentials_oauth_sr_client.get_subjects()) + azure_imds_oauth_config = { + 'url': 'https://psrc-123456.us-east-1.aws.confluent.cloud', + 'bearer.auth.credentials.source': 'OAUTHBEARER_AZURE_IMDS', + 'bearer.auth.issuer.endpoint.query': 'resource=&api-version=&client_id=', + 'bearer.auth.logical.cluster': 'lsrc-12345', + 'bearer.auth.identity.pool.id': 'pool-abcd'} + + azure_imds_oauth_sr_client = SchemaRegistryClient(azure_imds_oauth_config) + print(azure_imds_oauth_sr_client.get_subjects()) + def custom_oauth_function(config): return config diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index effd20e4b..17589d2c4 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -23,6 +23,7 @@ import ssl import time import urllib +import abc from urllib.parse import unquote, urlparse import httpx @@ -36,6 +37,7 @@ from confluent_kafka.schema_registry.common._oauthbearer import ( _BearerFieldProvider, _AbstractOAuthBearerOIDCFieldProviderBuilder, + _AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder, _StaticOAuthBearerFieldProviderBuilder, _AbstractCustomOAuthBearerFieldProviderBuilder) from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError @@ -76,18 +78,15 @@ async def get_bearer_fields(self) -> dict: return await self.custom_function(self.custom_config) -class _AsyncOAuthClient(_BearerFieldProvider): - def __init__(self, client_id: str, client_secret: str, 
scope: str, token_endpoint: str, logical_cluster: str, +class _AsyncAbstractOAuthClient(_BearerFieldProvider): + def __init__(self, logical_cluster: str, identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): - self.token = None - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - self.client = AsyncOAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) - self.token_endpoint = token_endpoint - self.max_retries = max_retries - self.retries_wait_ms = retries_wait_ms - self.retries_max_wait_ms = retries_max_wait_ms - self.token_expiry_threshold = 0.8 + self.logical_cluster: str = logical_cluster + self.identity_pool: str = identity_pool + self.max_retries: int = max_retries + self.retries_wait_ms: int = retries_wait_ms + self.retries_max_wait_ms: int = retries_max_wait_ms + self.token: str = None async def get_bearer_fields(self) -> dict: return { @@ -96,21 +95,24 @@ async def get_bearer_fields(self) -> dict: 'bearer.auth.identity.pool.id': self.identity_pool } - def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold - - return self.token['expires_at'] < time.time() + expiry_window - async def get_access_token(self) -> str: if not self.token or self.token_expired(): await self.generate_access_token() - return self.token['access_token'] + return self.token + + @abc.abstractmethod + def token_expired(self) -> bool: + raise NotImplementedError + + @abc.abstractmethod + async def fetch_token(self) -> str: + raise NotImplementedError async def generate_access_token(self) -> None: for i in range(self.max_retries + 1): try: - self.token = await self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials') + self.token = await self.fetch_token() return except Exception as e: if i >= self.max_retries: @@ -119,9 +121,51 @@ async def generate_access_token(self) -> None: await asyncio.sleep(full_jitter(self.retries_wait_ms, 
class _AsyncOAuthClient(_AsyncAbstractOAuthClient):
    """Async bearer field provider using the OAuth 2.0 client-credentials grant.

    Fetches tokens from ``token_endpoint`` through authlib's
    ``AsyncOAuth2Client``; retry/backoff handling lives in the base class.
    """

    def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
                 identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        super().__init__(
            logical_cluster, identity_pool, max_retries, retries_wait_ms,
            retries_max_wait_ms)
        self.client = AsyncOAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
        self.token_endpoint: str = token_endpoint
        self.token_object: dict = None
        # Refresh once 80% of the token lifetime has elapsed.
        self.token_expiry_threshold: float = 0.8

    def token_expired(self) -> bool:
        expiry_window = self.token_object['expires_in'] * self.token_expiry_threshold
        return self.token_object['expires_at'] < time.time() + expiry_window

    async def fetch_token(self) -> str:
        self.token_object = await self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
        return self.token_object['access_token']


class _AsyncOAuthAzureIMDSClient(_AsyncAbstractOAuthClient):
    """Async bearer field provider using Azure IMDS metadata-based auth.

    Requests a token from the (link-local) IMDS endpoint with the mandatory
    ``Metadata: true`` header; no client credentials are needed.
    """

    def __init__(self, token_endpoint: str, logical_cluster: str,
                 identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        super().__init__(
            logical_cluster, identity_pool, max_retries, retries_wait_ms,
            retries_max_wait_ms)
        self.client = httpx.AsyncClient()
        self.token_endpoint: str = token_endpoint
        self.token_object: dict = None
        # Refresh once 80% of the token lifetime has elapsed.
        self.token_expiry_threshold: float = 0.8

    def token_expired(self) -> bool:
        # IMDS returns these fields as strings: 'expires_in' is the lifetime
        # in seconds, 'expires_on' an absolute epoch timestamp.
        expiry_window = int(self.token_object['expires_in']) * self.token_expiry_threshold
        return int(self.token_object['expires_on']) < time.time() + expiry_window

    async def fetch_token(self) -> str:
        # BUG FIX: AsyncClient.get() returns a coroutine, so the original
        # `await self.client.get(...).json()` raised AttributeError — the
        # response must be awaited first, then .json() read synchronously.
        response = await self.client.get(self.token_endpoint, headers=[
            ('Metadata', 'true')
        ])
        self.token_object = response.json()
        return self.token_object['access_token']
retries_wait_ms, retries_max_wait_ms): + def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): self._validate() return _AsyncOAuthClient( self.client_id, self.client_secret, self.scope, @@ -132,9 +176,21 @@ def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): retries_max_wait_ms) +class _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder(_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder): + + def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): + self._validate() + return _AsyncOAuthAzureIMDSClient( + self.token_endpoint, + self.logical_cluster, + self.identity_pool, + max_retries, retries_wait_ms, + retries_max_wait_ms) + + class _AsyncCustomOAuthBearerFieldProviderBuilder(_AbstractCustomOAuthBearerFieldProviderBuilder): - def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): self._validate() return _AsyncCustomOAuthClient( self.custom_function, @@ -146,12 +202,13 @@ class _AsyncFieldProviderBuilder: __builders = { "OAUTHBEARER": _AsyncOAuthBearerOIDCFieldProviderBuilder, + "OAUTHBEARER_AZURE_IMDS": _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder, "STATIC_TOKEN": _StaticOAuthBearerFieldProviderBuilder, "CUSTOM": _AsyncCustomOAuthBearerFieldProviderBuilder } @staticmethod - def build(conf, max_retries, retries_wait_ms, retries_max_wait_ms): + def build(conf, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): bearer_auth_credentials_source = conf.pop('bearer.auth.credentials.source', None) if bearer_auth_credentials_source is None: return [None, None] diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 2b3b5d10d..4907ea424 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ 
b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -23,6 +23,7 @@ import ssl import time import urllib +import abc from urllib.parse import unquote, urlparse import httpx @@ -36,6 +37,7 @@ from confluent_kafka.schema_registry.common._oauthbearer import ( _BearerFieldProvider, _AbstractOAuthBearerOIDCFieldProviderBuilder, + _AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder, _StaticOAuthBearerFieldProviderBuilder, _AbstractCustomOAuthBearerFieldProviderBuilder) from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError @@ -76,18 +78,15 @@ def get_bearer_fields(self) -> dict: return self.custom_function(self.custom_config) -class _OAuthClient(_BearerFieldProvider): - def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, +class _AbstractOAuthClient(_BearerFieldProvider): + def __init__(self, logical_cluster: str, identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): - self.token = None - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) - self.token_endpoint = token_endpoint - self.max_retries = max_retries - self.retries_wait_ms = retries_wait_ms - self.retries_max_wait_ms = retries_max_wait_ms - self.token_expiry_threshold = 0.8 + self.logical_cluster: str = logical_cluster + self.identity_pool: str = identity_pool + self.max_retries: int = max_retries + self.retries_wait_ms: int = retries_wait_ms + self.retries_max_wait_ms: int = retries_max_wait_ms + self.token: str = None def get_bearer_fields(self) -> dict: return { @@ -96,21 +95,24 @@ def get_bearer_fields(self) -> dict: 'bearer.auth.identity.pool.id': self.identity_pool } - def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold - - return self.token['expires_at'] < time.time() + expiry_window - def 
class _OAuthClient(_AbstractOAuthClient):
    """Sync bearer field provider using the OAuth 2.0 client-credentials grant.

    Tokens are obtained from the configured endpoint via authlib's
    ``OAuth2Client``; retry/backoff handling lives in the base class.
    """

    def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
                 identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        super().__init__(
            logical_cluster, identity_pool, max_retries, retries_wait_ms,
            retries_max_wait_ms)
        self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
        self.token_endpoint: str = token_endpoint
        self.token_object: dict = None
        # Treat the token as expired once 80% of its lifetime has elapsed.
        self.token_expiry_threshold: float = 0.8

    def token_expired(self) -> bool:
        lifetime_portion = self.token_object['expires_in'] * self.token_expiry_threshold
        deadline = time.time() + lifetime_portion
        return self.token_object['expires_at'] < deadline

    def fetch_token(self) -> str:
        token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
        self.token_object = token
        return token['access_token']


class _OAuthAzureIMDSClient(_AbstractOAuthClient):
    """Sync bearer field provider using Azure IMDS metadata-based auth.

    Requests a token from the IMDS endpoint with the mandatory
    ``Metadata: true`` header; no client credentials are needed.
    """

    def __init__(self, token_endpoint: str, logical_cluster: str,
                 identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        super().__init__(
            logical_cluster, identity_pool, max_retries, retries_wait_ms,
            retries_max_wait_ms)
        self.client = httpx.Client()
        self.token_endpoint: str = token_endpoint
        self.token_object: dict = None
        # Treat the token as expired once 80% of its lifetime has elapsed.
        self.token_expiry_threshold: float = 0.8

    def token_expired(self) -> bool:
        # IMDS returns these fields as strings: 'expires_in' is the lifetime
        # in seconds, 'expires_on' an absolute epoch timestamp.
        lifetime_portion = int(self.token_object['expires_in']) * self.token_expiry_threshold
        deadline = time.time() + lifetime_portion
        return int(self.token_object['expires_on']) < deadline

    def fetch_token(self) -> str:
        response = self.client.get(self.token_endpoint, headers=[
            ('Metadata', 'true')
        ])
        self.token_object = response.json()
        return self.token_object['access_token']


class _OAuthBearerOIDCAzureIMDSFieldProviderBuilder(_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder):
    """Builds the sync Azure IMDS field provider after validating the config."""

    def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        # Validation resolves self.token_endpoint (default or override + query)
        # before the provider is constructed.
        self._validate()
        provider = _OAuthAzureIMDSClient(
            self.token_endpoint,
            self.logical_cluster,
            self.identity_pool,
            max_retries,
            retries_wait_ms,
            retries_max_wait_ms)
        return provider
max_retries, retries_wait_ms, retries_max_wait_ms): + def build(conf, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): bearer_auth_credentials_source = conf.pop('bearer.auth.credentials.source', None) if bearer_auth_credentials_source is None: return [None, None] diff --git a/src/confluent_kafka/schema_registry/common/_oauthbearer.py b/src/confluent_kafka/schema_registry/common/_oauthbearer.py index 5261629c6..1df046b0b 100644 --- a/src/confluent_kafka/schema_registry/common/_oauthbearer.py +++ b/src/confluent_kafka/schema_registry/common/_oauthbearer.py @@ -17,23 +17,31 @@ # import abc +from urllib.parse import urlparse, urlunparse __all__ = [ '_AbstractOAuthBearerFieldProviderBuilder', '_AbstractOAuthBearerOIDCFieldProviderBuilder', '_StaticOAuthBearerFieldProviderBuilder', - '_AbstractCustomOAuthBearerFieldProviderBuilder' + '_AbstractCustomOAuthBearerFieldProviderBuilder', + '_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder' ] +class _BearerFieldProvider(metaclass=abc.ABCMeta): + @abc.abstractmethod + def get_bearer_fields(self) -> dict: + raise NotImplementedError + + class _AbstractOAuthBearerFieldProviderBuilder(metaclass=abc.ABCMeta): """Abstract base class for OAuthBearer client builders""" required_properties = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - def __init__(self, conf): - self.conf = conf - self.logical_cluster = None - self.identity_pool = None + def __init__(self, conf: dict): + self.conf: dict = conf + self.logical_cluster: str = None + self.identity_pool: str = None def _validate(self): missing_properties = [prop for prop in @@ -54,7 +62,7 @@ def _validate(self): str(type(self.identity_pool))) @abc.abstractmethod - def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int) -> _BearerFieldProvider: pass @@ -63,12 +71,12 @@ class _AbstractOAuthBearerOIDCFieldProviderBuilder(_AbstractOAuthBearerFieldProv 
class _AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder(_AbstractOAuthBearerFieldProviderBuilder):
    """Validates configuration for Azure IMDS metadata-based authentication.

    Resolves the token endpoint from ``bearer.auth.issuer.endpoint.url``
    (defaulting to the well-known link-local IMDS address) and, when
    provided, replaces its query string with
    ``bearer.auth.issuer.endpoint.query``. Concrete sync/async subclasses
    implement ``build()``.
    """

    def __init__(self, conf: dict):
        super().__init__(conf)
        # Well-known, link-local Azure IMDS token endpoint; overridable via
        # 'bearer.auth.issuer.endpoint.url'.
        self.token_endpoint: str = 'http://169.254.169.254/metadata/identity/oauth2/token'

    def _validate(self):
        """Check types and assemble the final token endpoint URL.

        Raises:
            TypeError: if the endpoint URL or query is not a str.
            ValueError: if the URL cannot be parsed, or if neither the
                query nor an endpoint override is provided.
        """
        super()._validate()

        token_endpoint_override = 'bearer.auth.issuer.endpoint.url' in self.conf
        self.token_endpoint = self.conf.pop('bearer.auth.issuer.endpoint.url', self.token_endpoint)
        if not isinstance(self.token_endpoint, str):
            raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
                            + str(type(self.token_endpoint)))

        try:
            parsed_token_endpoint = urlparse(self.token_endpoint)
        except Exception as ex:
            # Chain the original parse error for easier debugging.
            raise ValueError(f'Failed to parse token endpoint URL: {ex}') from ex

        token_query = self.conf.pop('bearer.auth.issuer.endpoint.query', None)
        if token_query:
            if not isinstance(token_query, str):
                raise TypeError("bearer.auth.issuer.endpoint.query must be a str, not "
                                + str(type(token_query)))

            # Replace any query/fragment already present in the configured URL
            # with the configured query string. '' (not None): ParseResult
            # components are str; behavior under urlunparse is unchanged.
            parsed_token_endpoint = parsed_token_endpoint._replace(
                query=token_query,
                fragment='')
            self.token_endpoint = urlunparse(parsed_token_endpoint)
        elif not token_endpoint_override:
            # Typo fixed: "isn overridden" -> "is not overridden".
            raise ValueError("bearer.auth.issuer.endpoint.query must be provided "
                             "when bearer.auth.issuer.endpoint.url is not overridden")
_StaticFieldProvider(_BearerFieldProvider): def __init__(self, token: str, logical_cluster: str, identity_pool: str): - self.token = token - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool + self.token: str = token + self.logical_cluster: str = logical_cluster + self.identity_pool: str = identity_pool def get_bearer_fields(self) -> dict: return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, @@ -123,9 +160,9 @@ class _StaticOAuthBearerFieldProviderBuilder(_AbstractOAuthBearerFieldProviderBu 'bearer.auth.scope', 'bearer.auth.issuer.endpoint.url'] - def __init__(self, conf): + def __init__(self, conf: dict): super().__init__(conf) - self.static_token = None + self.static_token: str = None def _validate(self): super()._validate() @@ -137,7 +174,7 @@ def _validate(self): raise TypeError("bearer.auth.token must be a str, not " + str(type(self.static_token))) - def build(self, max_retries, retries_wait_ms, retries_max_wait_ms): + def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): self._validate() return _StaticFieldProvider( self.static_token, @@ -150,7 +187,7 @@ class _AbstractCustomOAuthBearerFieldProviderBuilder: required_properties = ['bearer.auth.custom.provider.function', 'bearer.auth.custom.provider.config'] - def __init__(self, conf): + def __init__(self, conf: dict): self.conf = conf self.custom_function = None self.custom_config = None diff --git a/tests/schema_registry/_async/test_bearer_field_provider.py b/tests/schema_registry/_async/test_bearer_field_provider.py index cb405ef26..7046913ba 100644 --- a/tests/schema_registry/_async/test_bearer_field_provider.py +++ b/tests/schema_registry/_async/test_bearer_field_provider.py @@ -19,7 +19,9 @@ import time from unittest.mock import AsyncMock, patch -from confluent_kafka.schema_registry._async.schema_registry_client import _AsyncOAuthClient, AsyncSchemaRegistryClient +from 
confluent_kafka.schema_registry._async.schema_registry_client import ( + _AsyncOAuthClient, _AsyncOAuthAzureIMDSClient, AsyncSchemaRegistryClient +) from confluent_kafka.schema_registry._async.schema_registry_client import _AsyncCustomOAuthClient from confluent_kafka.schema_registry.common._oauthbearer import _StaticFieldProvider from confluent_kafka.schema_registry.error import OAuthTokenError @@ -45,7 +47,7 @@ async def custom_oauth_function(config: dict) -> dict: def test_expiry(): oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) - oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1} + oauth_client.token_object = {'expires_at': time.time() + 2, 'expires_in': 1} assert not oauth_client.token_expired() time.sleep(1.5) assert oauth_client.token_expired() @@ -55,21 +57,49 @@ async def test_get_token(): oauth_client = _AsyncOAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) def update_token1(): - oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} + oauth_client.token_object = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} + oauth_client.token = '123' def update_token2(): - oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'} + oauth_client.token_object = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'} + oauth_client.token = '1234' oauth_client.generate_access_token = AsyncMock(side_effect=update_token1) await oauth_client.get_access_token() assert oauth_client.generate_access_token.call_count == 1 - assert oauth_client.token['access_token'] == '123' + assert oauth_client.token_object['access_token'] == '123' oauth_client.generate_access_token = AsyncMock(side_effect=update_token2) await oauth_client.get_access_token() # Call count resets to 1 after reassigning generate_access_token assert oauth_client.generate_access_token.call_count == 1 - assert 
oauth_client.token['access_token'] == '1234' + assert oauth_client.token_object['access_token'] == '1234' + + await oauth_client.get_access_token() + assert oauth_client.generate_access_token.call_count == 1 + + +async def test_get_token_azure_imds(): + oauth_client = _AsyncOAuthAzureIMDSClient('endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + + def update_token1(): + oauth_client.token_object = {'expires_on': 0, 'expires_in': 1, 'access_token': '123'} + oauth_client.token = '123' + + def update_token2(): + oauth_client.token_object = {'expires_on': time.time() + 2, 'expires_in': 1, 'access_token': '1234'} + oauth_client.token = '1234' + + oauth_client.generate_access_token = AsyncMock(side_effect=update_token1) + await oauth_client.get_access_token() + assert oauth_client.generate_access_token.call_count == 1 + assert oauth_client.token_object['access_token'] == '123' + + oauth_client.generate_access_token = AsyncMock(side_effect=update_token2) + await oauth_client.get_access_token() + # Call count resets to 1 after reassigning generate_access_token + assert oauth_client.generate_access_token.call_count == 1 + assert oauth_client.token_object['access_token'] == '1234' await oauth_client.get_access_token() assert oauth_client.generate_access_token.call_count == 1 diff --git a/tests/schema_registry/_async/test_config.py b/tests/schema_registry/_async/test_config.py index c6ee13fdb..882abb1f8 100644 --- a/tests/schema_registry/_async/test_config.py +++ b/tests/schema_registry/_async/test_config.py @@ -228,6 +228,88 @@ def test_oauth_bearer_config_valid(): assert client._rest_client.bearer_field_provider.token_endpoint == TEST_ENDPOINT +def test_oauth_bearer_azure_imds_config_invalid(): + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': 1} + + with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"): + 
AsyncSchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': 1, + 'bearer.auth.identity.pool.id': TEST_POOL} + + with pytest.raises(TypeError, match=r"logical cluster must be a str, not (.*)"): + AsyncSchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.url': 1} + + with pytest.raises(TypeError, match=r"bearer.auth.issuer.endpoint.url must be a str, not (.*)"): + AsyncSchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.url': 'http://[wrong_url'} + + with pytest.raises(ValueError, match=r"Failed to parse token endpoint URL: (.*)"): + AsyncSchemaRegistryClient(conf) + + for url in [{'bearer.auth.issuer.endpoint.url': 'http://test'}, {}]: + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.query': 1, + **url} + + with pytest.raises(TypeError, match=r"bearer.auth.issuer.endpoint.query must be a str, not (.*)"): + AsyncSchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL} + + with pytest.raises(ValueError, match=r"bearer.auth.issuer.endpoint.query must be provided (.*)"): + AsyncSchemaRegistryClient(conf) + + +def test_oauth_bearer_azure_imds_config_valid(): + expected_token_endpoints = { + 'http://alias': 'http://alias', + '': 'http://169.254.169.254/metadata/identity/oauth2/token' + } + query = 
'resource=api://test&api-version=2018-02-01' + + for url in [{'bearer.auth.issuer.endpoint.url': 'http://alias'}, {}]: + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.query': query, + **url} + + client = AsyncSchemaRegistryClient(conf) + if 'bearer.auth.issuer.endpoint.url' in url: + expected_token_endpoint = expected_token_endpoints[ + url['bearer.auth.issuer.endpoint.url'] + ] + else: + expected_token_endpoint = expected_token_endpoints[''] + expected_token_endpoint += f'?{query}' + + assert client._rest_client.bearer_field_provider.token_endpoint == expected_token_endpoint + + def test_static_bearer_config(): conf = {'url': TEST_URL, 'bearer.auth.credentials.source': 'STATIC_TOKEN', diff --git a/tests/schema_registry/_sync/test_bearer_field_provider.py b/tests/schema_registry/_sync/test_bearer_field_provider.py index 2e0f3e400..6bdd461b8 100644 --- a/tests/schema_registry/_sync/test_bearer_field_provider.py +++ b/tests/schema_registry/_sync/test_bearer_field_provider.py @@ -19,7 +19,9 @@ import time from unittest.mock import Mock, patch -from confluent_kafka.schema_registry._sync.schema_registry_client import _OAuthClient, SchemaRegistryClient +from confluent_kafka.schema_registry._sync.schema_registry_client import ( + _OAuthClient, _OAuthAzureIMDSClient, SchemaRegistryClient +) from confluent_kafka.schema_registry._sync.schema_registry_client import _CustomOAuthClient from confluent_kafka.schema_registry.common._oauthbearer import _StaticFieldProvider from confluent_kafka.schema_registry.error import OAuthTokenError @@ -45,7 +47,7 @@ def custom_oauth_function(config: dict) -> dict: def test_expiry(): oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) - oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1} + 
oauth_client.token_object = {'expires_at': time.time() + 2, 'expires_in': 1} assert not oauth_client.token_expired() time.sleep(1.5) assert oauth_client.token_expired() @@ -55,21 +57,49 @@ def test_get_token(): oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) def update_token1(): - oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} + oauth_client.token_object = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'} + oauth_client.token = '123' def update_token2(): - oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'} + oauth_client.token_object = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'} + oauth_client.token = '1234' oauth_client.generate_access_token = Mock(side_effect=update_token1) oauth_client.get_access_token() assert oauth_client.generate_access_token.call_count == 1 - assert oauth_client.token['access_token'] == '123' + assert oauth_client.token_object['access_token'] == '123' oauth_client.generate_access_token = Mock(side_effect=update_token2) oauth_client.get_access_token() # Call count resets to 1 after reassigning generate_access_token assert oauth_client.generate_access_token.call_count == 1 - assert oauth_client.token['access_token'] == '1234' + assert oauth_client.token_object['access_token'] == '1234' + + oauth_client.get_access_token() + assert oauth_client.generate_access_token.call_count == 1 + + +def test_get_token_azure_imds(): + oauth_client = _OAuthAzureIMDSClient('endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000) + + def update_token1(): + oauth_client.token_object = {'expires_on': 0, 'expires_in': 1, 'access_token': '123'} + oauth_client.token = '123' + + def update_token2(): + oauth_client.token_object = {'expires_on': time.time() + 2, 'expires_in': 1, 'access_token': '1234'} + oauth_client.token = '1234' + + oauth_client.generate_access_token = Mock(side_effect=update_token1) + 
oauth_client.get_access_token() + assert oauth_client.generate_access_token.call_count == 1 + assert oauth_client.token_object['access_token'] == '123' + + oauth_client.generate_access_token = Mock(side_effect=update_token2) + oauth_client.get_access_token() + # Call count resets to 1 after reassigning generate_access_token + assert oauth_client.generate_access_token.call_count == 1 + assert oauth_client.token_object['access_token'] == '1234' oauth_client.get_access_token() assert oauth_client.generate_access_token.call_count == 1 diff --git a/tests/schema_registry/_sync/test_config.py b/tests/schema_registry/_sync/test_config.py index c075e869e..e124d9a84 100644 --- a/tests/schema_registry/_sync/test_config.py +++ b/tests/schema_registry/_sync/test_config.py @@ -228,6 +228,88 @@ def test_oauth_bearer_config_valid(): assert client._rest_client.bearer_field_provider.token_endpoint == TEST_ENDPOINT +def test_oauth_bearer_azure_imds_config_invalid(): + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': 1} + + with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"): + SchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': 1, + 'bearer.auth.identity.pool.id': TEST_POOL} + + with pytest.raises(TypeError, match=r"logical cluster must be a str, not (.*)"): + SchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.url': 1} + + with pytest.raises(TypeError, match=r"bearer.auth.issuer.endpoint.url must be a str, not (.*)"): + SchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 
'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.url': 'http://[wrong_url'} + + with pytest.raises(ValueError, match=r"Failed to parse token endpoint URL: (.*)"): + SchemaRegistryClient(conf) + + for url in [{'bearer.auth.issuer.endpoint.url': 'http://test'}, {}]: + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.query': 1, + **url} + + with pytest.raises(TypeError, match=r"bearer.auth.issuer.endpoint.query must be a str, not (.*)"): + SchemaRegistryClient(conf) + + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL} + + with pytest.raises(ValueError, match=r"bearer.auth.issuer.endpoint.query must be provided (.*)"): + SchemaRegistryClient(conf) + + +def test_oauth_bearer_azure_imds_config_valid(): + expected_token_endpoints = { + 'http://alias': 'http://alias', + '': 'http://169.254.169.254/metadata/identity/oauth2/token' + } + query = 'resource=api://test&api-version=2018-02-01' + + for url in [{'bearer.auth.issuer.endpoint.url': 'http://alias'}, {}]: + conf = {'url': TEST_URL, + 'bearer.auth.credentials.source': "OAUTHBEARER_AZURE_IMDS", + 'bearer.auth.logical.cluster': TEST_CLUSTER, + 'bearer.auth.identity.pool.id': TEST_POOL, + 'bearer.auth.issuer.endpoint.query': query, + **url} + + client = SchemaRegistryClient(conf) + if 'bearer.auth.issuer.endpoint.url' in url: + expected_token_endpoint = expected_token_endpoints[ + url['bearer.auth.issuer.endpoint.url'] + ] + else: + expected_token_endpoint = expected_token_endpoints[''] + expected_token_endpoint += f'?{query}' + + assert client._rest_client.bearer_field_provider.token_endpoint == expected_token_endpoint + + def test_static_bearer_config(): conf = 
{'url': TEST_URL, 'bearer.auth.credentials.source': 'STATIC_TOKEN', From ab09f2848e4c96ed8be71bb95f1abc41365bd299 Mon Sep 17 00:00:00 2001 From: Emanuele Sabellico Date: Thu, 25 Sep 2025 10:02:55 +0200 Subject: [PATCH 3/4] Add SR serialization to the Azure IMDS producer example --- .../oauth_oidc_ccloud_azure_imds_producer.py | 119 +++++++++++++++--- .../schema_registry/common/_oauthbearer.py | 2 +- 2 files changed, 104 insertions(+), 17 deletions(-) diff --git a/examples/oauth_oidc_ccloud_azure_imds_producer.py b/examples/oauth_oidc_ccloud_azure_imds_producer.py index e46252e36..815622210 100644 --- a/examples/oauth_oidc_ccloud_azure_imds_producer.py +++ b/examples/oauth_oidc_ccloud_azure_imds_producer.py @@ -22,7 +22,52 @@ import logging import argparse from confluent_kafka import Producer -from confluent_kafka.serialization import StringSerializer +from confluent_kafka.schema_registry.json_schema import JSONSerializer +from confluent_kafka.serialization import (StringSerializer, + SerializationContext, MessageField) +from confluent_kafka.schema_registry import SchemaRegistryClient + + +class User(object): + """ + User record + + Args: + name (str): User's name + + favorite_number (int): User's favorite number + + favorite_color (str): User's favorite color + + address(str): User's address; confidential + """ + + def __init__(self, name, address, favorite_number, favorite_color): + self.name = name + self.favorite_number = favorite_number + self.favorite_color = favorite_color + # address should not be serialized, see user_to_dict() + self._address = address + + +def user_to_dict(user, ctx): + """ + Returns a dict representation of a User instance for serialization. + + Args: + user (User): User instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: Dict populated with user attributes to be serialized. 
+ """ + + # User._address must not be serialized; omit from dict + return dict(name=user.name, + favorite_number=user.favorite_number, + favorite_color=user.favorite_color) def producer_config(args): @@ -45,6 +90,21 @@ def producer_config(args): return params +def schema_registry_config(args): + params = { + 'url': args.schema_registry, + 'bearer.auth.credentials.source': 'OAUTHBEARER_AZURE_IMDS', + 'bearer.auth.issuer.endpoint.query': args.query, + } + # These two parameters are only applicable when producing to + # confluent cloud where some sasl extensions are required. + if args.logical_cluster and args.identity_pool_id: + params['bearer.auth.logical.cluster'] = args.logical_cluster + params['bearer.auth.identity.pool.id'] = args.identity_pool_id + + return params + + def delivery_report(err, msg): """ Reports the failure or success of a message delivery. @@ -72,27 +132,54 @@ def delivery_report(err, msg): def main(args): topic = args.topic - delimiter = args.delimiter producer_conf = producer_config(args) producer = Producer(producer_conf) - serializer = StringSerializer('utf_8') + string_serializer = StringSerializer('utf_8') + schema_str = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "User", + "description": "A Confluent Kafka Python User", + "type": "object", + "properties": { + "name": { + "description": "User's name", + "type": "string" + }, + "favorite_number": { + "description": "User's favorite number", + "type": "number", + "exclusiveMinimum": 0 + }, + "favorite_color": { + "description": "User's favorite color", + "type": "string" + } + }, + "required": [ "name", "favorite_number", "favorite_color" ] + } + """ + schema_registry_conf = schema_registry_config(args) + schema_registry_client = SchemaRegistryClient(schema_registry_conf) + + string_serializer = StringSerializer('utf_8') + json_serializer = JSONSerializer(schema_str, schema_registry_client, user_to_dict) print('Producing records to topic {}. 
^C to exit.'.format(topic)) while True: # Serve on_delivery callbacks from previous calls to produce() producer.poll(0.0) try: - msg_data = input(">") - msg = msg_data.split(delimiter) - if len(msg) == 2: - producer.produce(topic=topic, - key=serializer(msg[0]), - value=serializer(msg[1]), - on_delivery=delivery_report) - else: - producer.produce(topic=topic, - value=serializer(msg[0]), - on_delivery=delivery_report) + name = input(">") + user = User(name=name, + address="NA", + favorite_color="blue", + favorite_number=7) + serialized_user = json_serializer(user, SerializationContext(topic, MessageField.VALUE)) + producer.produce(topic=topic, + key=string_serializer(name), + value=serialized_user, + on_delivery=delivery_report) except KeyboardInterrupt: break @@ -106,8 +193,8 @@ def main(args): help="Bootstrap broker(s) (host[:port])") parser.add_argument('-t', dest="topic", default="example_producer_oauth", help="Topic name") - parser.add_argument('-d', dest="delimiter", default="|", - help="Key-Value delimiter. 
Defaults to '|'"),
+    parser.add_argument('-s', dest="schema_registry", required=True,
+                        help="Schema Registry (http(s)://host[:port])")
     parser.add_argument('--query', dest="query", required=True, help="Query parameters for Azure IMDS token endpoint")
     parser.add_argument('--logical-cluster', dest="logical_cluster", required=False, help="Logical Cluster.")
     parser.add_argument('--identity-pool-id', dest="identity_pool_id", required=False, help="Identity Pool ID.")

diff --git a/src/confluent_kafka/schema_registry/common/_oauthbearer.py b/src/confluent_kafka/schema_registry/common/_oauthbearer.py
index 1df046b0b..67160020e 100644
--- a/src/confluent_kafka/schema_registry/common/_oauthbearer.py
+++ b/src/confluent_kafka/schema_registry/common/_oauthbearer.py
@@ -141,7 +141,7 @@ def _validate(self):
             self.token_endpoint = urlunparse(parsed_token_endpoint)
         elif not token_endpoint_override:
             raise ValueError("bearer.auth.issuer.endpoint.query must be provided "
-                             "when bearer.auth.issuer.endpoint.url isn overridden")
+                             "when bearer.auth.issuer.endpoint.url isn't overridden")
 
 
 class _StaticFieldProvider(_BearerFieldProvider):

From 87b68980916b266c180adb35f9000bbd64e7f2df Mon Sep 17 00:00:00 2001
From: Emanuele Sabellico
Date: Fri, 26 Sep 2025 21:41:11 +0200
Subject: [PATCH 4/4] Add parameter for logical schema registry cluster

---
 examples/oauth_oidc_ccloud_azure_imds_producer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/oauth_oidc_ccloud_azure_imds_producer.py b/examples/oauth_oidc_ccloud_azure_imds_producer.py
index 815622210..aa1cf5569 100644
--- a/examples/oauth_oidc_ccloud_azure_imds_producer.py
+++ b/examples/oauth_oidc_ccloud_azure_imds_producer.py
@@ -98,8 +98,8 @@ def schema_registry_config(args):
     }
     # These two parameters are only applicable when producing to
     # confluent cloud where some sasl extensions are required.
- if args.logical_cluster and args.identity_pool_id: - params['bearer.auth.logical.cluster'] = args.logical_cluster + if args.logical_schema_registry_cluster and args.identity_pool_id: + params['bearer.auth.logical.cluster'] = args.logical_schema_registry_cluster params['bearer.auth.identity.pool.id'] = args.identity_pool_id return params @@ -198,6 +198,10 @@ def main(args): parser.add_argument('--query', dest="query", required=True, help="Query parameters for Azure IMDS token endpoint") parser.add_argument('--logical-cluster', dest="logical_cluster", required=False, help="Logical Cluster.") + parser.add_argument('--logical-schema-registry-cluster', + dest="logical_schema_registry_cluster", + required=False, + help="Logical Schema Registry Cluster.") parser.add_argument('--identity-pool-id', dest="identity_pool_id", required=False, help="Identity Pool ID.") main(parser.parse_args())