Skip to content

Commit ccc80a6

Browse files
committed
OAuth Azure IMDB implementation
1 parent 6aea1a5 commit ccc80a6

File tree

11 files changed

+613
-77
lines changed

11 files changed

+613
-77
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
# Confluent's Python client for Apache Kafka
22

3+
## v2.12.0
4+
5+
v2.12.0 is a feature release with the following enhancements:
6+
7+
- OAuth/OIDC metadata based authentication with Azure IMDS (#).
8+
9+
confluent-kafka-python v2.12.0 is based on librdkafka v2.12.0, see the
10+
[librdkafka release notes](https://github.com/confluentinc/librdkafka/releases/tag/v2.12.0)
11+
for a complete list of changes, enhancements, fixes and upgrade considerations.
12+
13+
314
## v2.11.1
415

516
v2.11.1 is a maintenance release with the following fixes:

examples/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ The scripts in this directory provide various examples of using Confluent's Pyth
1818
Additional examples for [Confluent Cloud](https://www.confluent.io/confluent-cloud/):
1919

2020
* [confluent_cloud.py](confluent_cloud.py): Produce messages to Confluent Cloud and then read them back again.
21+
* [oauth_oidc_ccloud_producer.py](oauth_oidc_ccloud_producer.py): Demonstrates OAuth/OIDC Authentication with Confluent Cloud (client credentials).
22+
* [oauth_oidc_ccloud_azure_imds_producer.py](oauth_oidc_ccloud_azure_imds_producer.py): Demonstrates OAuth/OIDC Authentication with Confluent Cloud (Azure IMDS metadata based authentication).
2123
* [confluentinc/examples](https://github.com/confluentinc/examples/tree/master/clients/cloud/python): Integration with Confluent Cloud and Confluent Cloud Schema Registry
2224

2325
## venv setup
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2025 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
# This example use Azure IMDS for credential-less authentication
20+
# through to Schema Registry on Confluent Cloud
21+
22+
import logging
23+
import argparse
24+
from confluent_kafka import Producer
25+
from confluent_kafka.serialization import StringSerializer
26+
27+
28+
def producer_config(args):
29+
logger = logging.getLogger(__name__)
30+
logger.setLevel(logging.DEBUG)
31+
params = {
32+
'bootstrap.servers': args.bootstrap_servers,
33+
'security.protocol': 'SASL_SSL',
34+
'sasl.mechanisms': 'OAUTHBEARER',
35+
'sasl.oauthbearer.method': 'oidc',
36+
'sasl.oauthbearer.metadata.authentication.type': 'azure_imds',
37+
'sasl.oauthbearer.config': f'query={args.query}'
38+
}
39+
# These two parameters are only applicable when producing to
40+
# confluent cloud where some sasl extensions are required.
41+
if args.logical_cluster and args.identity_pool_id:
42+
params['sasl.oauthbearer.extensions'] = 'logicalCluster=' + args.logical_cluster + \
43+
',identityPoolId=' + args.identity_pool_id
44+
45+
return params
46+
47+
48+
def delivery_report(err, msg):
49+
"""
50+
Reports the failure or success of a message delivery.
51+
52+
Args:
53+
err (KafkaError): The error that occurred on None on success.
54+
55+
msg (Message): The message that was produced or failed.
56+
57+
Note:
58+
In the delivery report callback the Message.key() and Message.value()
59+
will be the binary format as encoded by any configured Serializers and
60+
not the same object that was passed to produce().
61+
If you wish to pass the original object(s) for key and value to delivery
62+
report callback we recommend a bound callback or lambda where you pass
63+
the objects along.
64+
65+
"""
66+
if err is not None:
67+
print('Delivery failed for User record {}: {}'.format(msg.key(), err))
68+
return
69+
print('User record {} successfully produced to {} [{}] at offset {}'.format(
70+
msg.key(), msg.topic(), msg.partition(), msg.offset()))
71+
72+
73+
def main(args):
74+
topic = args.topic
75+
delimiter = args.delimiter
76+
producer_conf = producer_config(args)
77+
producer = Producer(producer_conf)
78+
serializer = StringSerializer('utf_8')
79+
80+
print('Producing records to topic {}. ^C to exit.'.format(topic))
81+
while True:
82+
# Serve on_delivery callbacks from previous calls to produce()
83+
producer.poll(0.0)
84+
try:
85+
msg_data = input(">")
86+
msg = msg_data.split(delimiter)
87+
if len(msg) == 2:
88+
producer.produce(topic=topic,
89+
key=serializer(msg[0]),
90+
value=serializer(msg[1]),
91+
on_delivery=delivery_report)
92+
else:
93+
producer.produce(topic=topic,
94+
value=serializer(msg[0]),
95+
on_delivery=delivery_report)
96+
except KeyboardInterrupt:
97+
break
98+
99+
print('\nFlushing {} records...'.format(len(producer)))
100+
producer.flush()
101+
102+
103+
if __name__ == '__main__':
104+
parser = argparse.ArgumentParser(description="OAUTH example with client credentials grant")
105+
parser.add_argument('-b', dest="bootstrap_servers", required=True,
106+
help="Bootstrap broker(s) (host[:port])")
107+
parser.add_argument('-t', dest="topic", default="example_producer_oauth",
108+
help="Topic name")
109+
parser.add_argument('-d', dest="delimiter", default="|",
110+
help="Key-Value delimiter. Defaults to '|'"),
111+
parser.add_argument('--query', dest="query", required=True,
112+
help="Query parameters for Azure IMDS token endpoint")
113+
parser.add_argument('--logical-cluster', dest="logical_cluster", required=False, help="Logical Cluster.")
114+
parser.add_argument('--identity-pool-id', dest="identity_pool_id", required=False, help="Identity Pool ID.")
115+
116+
main(parser.parse_args())

examples/oauth_schema_registry.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# limitations under the License.
1717

1818
# Examples of setting up Schema Registry with OAuth with static token,
19-
# Client Credentials, and custom functions
19+
# Client Credentials, Azure IMDS, and custom functions
2020

2121

2222
# CUSTOM OAuth configuration takes in a custom function, config for that
@@ -49,6 +49,16 @@ def main():
4949
client_credentials_oauth_sr_client = SchemaRegistryClient(client_credentials_oauth_config)
5050
print(client_credentials_oauth_sr_client.get_subjects())
5151

52+
azure_imds_oauth_config = {
53+
'url': 'https://psrc-123456.us-east-1.aws.confluent.cloud',
54+
'bearer.auth.credentials.source': 'OAUTHBEARER_AZURE_IMDS',
55+
'bearer.auth.issuer.endpoint.query': 'resource=&api-version=&client_id=',
56+
'bearer.auth.logical.cluster': 'lsrc-12345',
57+
'bearer.auth.identity.pool.id': 'pool-abcd'}
58+
59+
azure_imds_oauth_sr_client = SchemaRegistryClient(azure_imds_oauth_config)
60+
print(azure_imds_oauth_sr_client.get_subjects())
61+
5262
def custom_oauth_function(config):
5363
return config
5464

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ssl
2424
import time
2525
import urllib
26+
import abc
2627
from urllib.parse import unquote, urlparse
2728

2829
import httpx
@@ -36,6 +37,7 @@
3637
from confluent_kafka.schema_registry.common._oauthbearer import (
3738
_BearerFieldProvider,
3839
_AbstractOAuthBearerOIDCFieldProviderBuilder,
40+
_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder,
3941
_StaticOAuthBearerFieldProviderBuilder,
4042
_AbstractCustomOAuthBearerFieldProviderBuilder)
4143
from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError
@@ -76,18 +78,15 @@ async def get_bearer_fields(self) -> dict:
7678
return await self.custom_function(self.custom_config)
7779

7880

79-
class _AsyncOAuthClient(_BearerFieldProvider):
80-
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
81+
class _AsyncAbstractOAuthClient(_BearerFieldProvider):
82+
def __init__(self, logical_cluster: str,
8183
identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
82-
self.token = None
83-
self.logical_cluster = logical_cluster
84-
self.identity_pool = identity_pool
85-
self.client = AsyncOAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
86-
self.token_endpoint = token_endpoint
87-
self.max_retries = max_retries
88-
self.retries_wait_ms = retries_wait_ms
89-
self.retries_max_wait_ms = retries_max_wait_ms
90-
self.token_expiry_threshold = 0.8
84+
self.logical_cluster: str = logical_cluster
85+
self.identity_pool: str = identity_pool
86+
self.max_retries: int = max_retries
87+
self.retries_wait_ms: int = retries_wait_ms
88+
self.retries_max_wait_ms: int = retries_max_wait_ms
89+
self.token: str = None
9190

9291
async def get_bearer_fields(self) -> dict:
9392
return {
@@ -96,21 +95,24 @@ async def get_bearer_fields(self) -> dict:
9695
'bearer.auth.identity.pool.id': self.identity_pool
9796
}
9897

99-
def token_expired(self) -> bool:
100-
expiry_window = self.token['expires_in'] * self.token_expiry_threshold
101-
102-
return self.token['expires_at'] < time.time() + expiry_window
103-
10498
async def get_access_token(self) -> str:
10599
if not self.token or self.token_expired():
106100
await self.generate_access_token()
107101

108-
return self.token['access_token']
102+
return self.token
103+
104+
@abc.abstractmethod
105+
def token_expired(self) -> bool:
106+
raise NotImplementedError
107+
108+
@abc.abstractmethod
109+
async def fetch_token(self) -> str:
110+
raise NotImplementedError
109111

110112
async def generate_access_token(self) -> None:
111113
for i in range(self.max_retries + 1):
112114
try:
113-
self.token = await self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
115+
self.token = await self.fetch_token()
114116
return
115117
except Exception as e:
116118
if i >= self.max_retries:
@@ -119,9 +121,51 @@ async def generate_access_token(self) -> None:
119121
await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000)
120122

121123

124+
class _AsyncOAuthClient(_AsyncAbstractOAuthClient):
125+
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
126+
identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
127+
super().__init__(
128+
logical_cluster, identity_pool, max_retries, retries_wait_ms,
129+
retries_max_wait_ms)
130+
self.client = AsyncOAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
131+
self.token_endpoint: str = token_endpoint
132+
self.token_object: dict = None
133+
self.token_expiry_threshold: float = 0.8
134+
135+
def token_expired(self) -> bool:
136+
expiry_window = self.token_object['expires_in'] * self.token_expiry_threshold
137+
return self.token_object['expires_at'] < time.time() + expiry_window
138+
139+
async def fetch_token(self) -> str:
140+
self.token_object = await self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
141+
return self.token_object['access_token']
142+
143+
144+
class _AsyncOAuthAzureIMDSClient(_AsyncAbstractOAuthClient):
145+
def __init__(self, token_endpoint: str, logical_cluster: str,
146+
identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
147+
super().__init__(
148+
logical_cluster, identity_pool, max_retries, retries_wait_ms,
149+
retries_max_wait_ms)
150+
self.client = httpx.AsyncClient()
151+
self.token_endpoint: str = token_endpoint
152+
self.token_object: dict = None
153+
self.token_expiry_threshold: float = 0.8
154+
155+
def token_expired(self) -> bool:
156+
expiry_window = int(self.token_object['expires_in']) * self.token_expiry_threshold
157+
return int(self.token_object['expires_on']) < time.time() + expiry_window
158+
159+
async def fetch_token(self) -> str:
160+
self.token_object = await self.client.get(self.token_endpoint, headers=[
161+
('Metadata', 'true')
162+
]).json()
163+
return self.token_object['access_token']
164+
165+
122166
class _AsyncOAuthBearerOIDCFieldProviderBuilder(_AbstractOAuthBearerOIDCFieldProviderBuilder):
123167

124-
def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
168+
def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
125169
self._validate()
126170
return _AsyncOAuthClient(
127171
self.client_id, self.client_secret, self.scope,
@@ -132,9 +176,21 @@ def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
132176
retries_max_wait_ms)
133177

134178

179+
class _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder(_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder):
180+
181+
def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
182+
self._validate()
183+
return _AsyncOAuthAzureIMDSClient(
184+
self.token_endpoint,
185+
self.logical_cluster,
186+
self.identity_pool,
187+
max_retries, retries_wait_ms,
188+
retries_max_wait_ms)
189+
190+
135191
class _AsyncCustomOAuthBearerFieldProviderBuilder(_AbstractCustomOAuthBearerFieldProviderBuilder):
136192

137-
def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
193+
def build(self, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
138194
self._validate()
139195
return _AsyncCustomOAuthClient(
140196
self.custom_function,
@@ -146,12 +202,13 @@ class _AsyncFieldProviderBuilder:
146202

147203
__builders = {
148204
"OAUTHBEARER": _AsyncOAuthBearerOIDCFieldProviderBuilder,
205+
"OAUTHBEARER_AZURE_IMDS": _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder,
149206
"STATIC_TOKEN": _StaticOAuthBearerFieldProviderBuilder,
150207
"CUSTOM": _AsyncCustomOAuthBearerFieldProviderBuilder
151208
}
152209

153210
@staticmethod
154-
def build(conf, max_retries, retries_wait_ms, retries_max_wait_ms):
211+
def build(conf, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
155212
bearer_auth_credentials_source = conf.pop('bearer.auth.credentials.source', None)
156213
if bearer_auth_credentials_source is None:
157214
return [None, None]

0 commit comments

Comments
 (0)