2323import ssl
2424import time
2525import urllib
26+ import abc
2627from urllib .parse import unquote , urlparse
2728
2829import httpx
3637from confluent_kafka .schema_registry .common ._oauthbearer import (
3738 _BearerFieldProvider ,
3839 _AbstractOAuthBearerOIDCFieldProviderBuilder ,
40+ _AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder ,
3941 _StaticOAuthBearerFieldProviderBuilder ,
4042 _AbstractCustomOAuthBearerFieldProviderBuilder )
4143from confluent_kafka .schema_registry .error import SchemaRegistryError , OAuthTokenError
@@ -76,18 +78,15 @@ async def get_bearer_fields(self) -> dict:
7678 return await self .custom_function (self .custom_config )
7779
7880
79- class _AsyncOAuthClient (_BearerFieldProvider ):
80- def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str , logical_cluster : str ,
81+ class _AsyncAbstractOAuthClient (_BearerFieldProvider ):
82+ def __init__ (self , logical_cluster : str ,
8183 identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
82- self .token = None
83- self .logical_cluster = logical_cluster
84- self .identity_pool = identity_pool
85- self .client = AsyncOAuth2Client (client_id = client_id , client_secret = client_secret , scope = scope )
86- self .token_endpoint = token_endpoint
87- self .max_retries = max_retries
88- self .retries_wait_ms = retries_wait_ms
89- self .retries_max_wait_ms = retries_max_wait_ms
90- self .token_expiry_threshold = 0.8
84+ self .logical_cluster : str = logical_cluster
85+ self .identity_pool : str = identity_pool
86+ self .max_retries : int = max_retries
87+ self .retries_wait_ms : int = retries_wait_ms
88+ self .retries_max_wait_ms : int = retries_max_wait_ms
89+ self .token : str = None
9190
9291 async def get_bearer_fields (self ) -> dict :
9392 return {
@@ -96,21 +95,24 @@ async def get_bearer_fields(self) -> dict:
9695 'bearer.auth.identity.pool.id' : self .identity_pool
9796 }
9897
99- def token_expired (self ) -> bool :
100- expiry_window = self .token ['expires_in' ] * self .token_expiry_threshold
101-
102- return self .token ['expires_at' ] < time .time () + expiry_window
103-
10498 async def get_access_token (self ) -> str :
10599 if not self .token or self .token_expired ():
106100 await self .generate_access_token ()
107101
108- return self .token ['access_token' ]
102+ return self .token
103+
104+ @abc .abstractmethod
105+ def token_expired (self ) -> bool :
106+ raise NotImplementedError
107+
108+ @abc .abstractmethod
109+ async def fetch_token (self ) -> str :
110+ raise NotImplementedError
109111
110112 async def generate_access_token (self ) -> None :
111113 for i in range (self .max_retries + 1 ):
112114 try :
113- self .token = await self .client . fetch_token (url = self . token_endpoint , grant_type = 'client_credentials' )
115+ self .token = await self .fetch_token ()
114116 return
115117 except Exception as e :
116118 if i >= self .max_retries :
@@ -119,9 +121,51 @@ async def generate_access_token(self) -> None:
119121 await asyncio .sleep (full_jitter (self .retries_wait_ms , self .retries_max_wait_ms , i ) / 1000 )
120122
121123
124+ class _AsyncOAuthClient (_AsyncAbstractOAuthClient ):
125+ def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str , logical_cluster : str ,
126+ identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
127+ super ().__init__ (
128+ logical_cluster , identity_pool , max_retries , retries_wait_ms ,
129+ retries_max_wait_ms )
130+ self .client = AsyncOAuth2Client (client_id = client_id , client_secret = client_secret , scope = scope )
131+ self .token_endpoint : str = token_endpoint
132+ self .token_object : dict = None
133+ self .token_expiry_threshold : float = 0.8
134+
135+ def token_expired (self ) -> bool :
136+ expiry_window = self .token_object ['expires_in' ] * self .token_expiry_threshold
137+ return self .token_object ['expires_at' ] < time .time () + expiry_window
138+
139+ async def fetch_token (self ) -> str :
140+ self .token_object = await self .client .fetch_token (url = self .token_endpoint , grant_type = 'client_credentials' )
141+ return self .token_object ['access_token' ]
142+
143+
144+ class _AsyncOAuthAzureIMDSClient (_AsyncAbstractOAuthClient ):
145+ def __init__ (self , token_endpoint : str , logical_cluster : str ,
146+ identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
147+ super ().__init__ (
148+ logical_cluster , identity_pool , max_retries , retries_wait_ms ,
149+ retries_max_wait_ms )
150+ self .client = httpx .AsyncClient ()
151+ self .token_endpoint : str = token_endpoint
152+ self .token_object : dict = None
153+ self .token_expiry_threshold : float = 0.8
154+
155+ def token_expired (self ) -> bool :
156+ expiry_window = int (self .token_object ['expires_in' ]) * self .token_expiry_threshold
157+ return int (self .token_object ['expires_on' ]) < time .time () + expiry_window
158+
159+ async def fetch_token (self ) -> str :
160+ self .token_object = await self .client .get (self .token_endpoint , headers = [
161+ ('Metadata' , 'true' )
162+ ]).json ()
163+ return self .token_object ['access_token' ]
164+
165+
122166class _AsyncOAuthBearerOIDCFieldProviderBuilder (_AbstractOAuthBearerOIDCFieldProviderBuilder ):
123167
124- def build (self , max_retries , retries_wait_ms , retries_max_wait_ms ):
168+ def build (self , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
125169 self ._validate ()
126170 return _AsyncOAuthClient (
127171 self .client_id , self .client_secret , self .scope ,
@@ -132,9 +176,21 @@ def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
132176 retries_max_wait_ms )
133177
134178
179+ class _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder (_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder ):
180+
181+ def build (self , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
182+ self ._validate ()
183+ return _AsyncOAuthAzureIMDSClient (
184+ self .token_endpoint ,
185+ self .logical_cluster ,
186+ self .identity_pool ,
187+ max_retries , retries_wait_ms ,
188+ retries_max_wait_ms )
189+
190+
135191class _AsyncCustomOAuthBearerFieldProviderBuilder (_AbstractCustomOAuthBearerFieldProviderBuilder ):
136192
137- def build (self , max_retries , retries_wait_ms , retries_max_wait_ms ):
193+ def build (self , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
138194 self ._validate ()
139195 return _AsyncCustomOAuthClient (
140196 self .custom_function ,
@@ -146,12 +202,13 @@ class _AsyncFieldProviderBuilder:
146202
147203 __builders = {
148204 "OAUTHBEARER" : _AsyncOAuthBearerOIDCFieldProviderBuilder ,
205+ "OAUTHBEARER_AZURE_IMDS" : _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder ,
149206 "STATIC_TOKEN" : _StaticOAuthBearerFieldProviderBuilder ,
150207 "CUSTOM" : _AsyncCustomOAuthBearerFieldProviderBuilder
151208 }
152209
153210 @staticmethod
154- def build (conf , max_retries , retries_wait_ms , retries_max_wait_ms ):
211+ def build (conf , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
155212 bearer_auth_credentials_source = conf .pop ('bearer.auth.credentials.source' , None )
156213 if bearer_auth_credentials_source is None :
157214 return [None , None ]
0 commit comments