2323import  ssl 
2424import  time 
2525import  urllib 
26+ import  abc 
2627from  urllib .parse  import  unquote , urlparse 
2728
2829import  httpx 
3637from  confluent_kafka .schema_registry .common ._oauthbearer  import  (
3738    _BearerFieldProvider ,
3839    _AbstractOAuthBearerOIDCFieldProviderBuilder ,
40+     _AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder ,
3941    _StaticOAuthBearerFieldProviderBuilder ,
4042    _AbstractCustomOAuthBearerFieldProviderBuilder )
4143from  confluent_kafka .schema_registry .error  import  SchemaRegistryError , OAuthTokenError 
@@ -76,18 +78,15 @@ async def get_bearer_fields(self) -> dict:
7678        return  await  self .custom_function (self .custom_config )
7779
7880
79- class  _AsyncOAuthClient (_BearerFieldProvider ):
80-     def  __init__ (self , client_id :  str ,  client_secret :  str ,  scope :  str ,  token_endpoint :  str ,  logical_cluster : str ,
81+ class  _AsyncAbstractOAuthClient (_BearerFieldProvider ):
82+     def  __init__ (self , logical_cluster : str ,
8183                 identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
82-         self .token  =  None 
83-         self .logical_cluster  =  logical_cluster 
84-         self .identity_pool  =  identity_pool 
85-         self .client  =  AsyncOAuth2Client (client_id = client_id , client_secret = client_secret , scope = scope )
86-         self .token_endpoint  =  token_endpoint 
87-         self .max_retries  =  max_retries 
88-         self .retries_wait_ms  =  retries_wait_ms 
89-         self .retries_max_wait_ms  =  retries_max_wait_ms 
90-         self .token_expiry_threshold  =  0.8 
84+         self .logical_cluster : str  =  logical_cluster 
85+         self .identity_pool : str  =  identity_pool 
86+         self .max_retries : int  =  max_retries 
87+         self .retries_wait_ms : int  =  retries_wait_ms 
88+         self .retries_max_wait_ms : int  =  retries_max_wait_ms 
89+         self .token : str  =  None 
9190
9291    async  def  get_bearer_fields (self ) ->  dict :
9392        return  {
@@ -96,21 +95,24 @@ async def get_bearer_fields(self) -> dict:
9695            'bearer.auth.identity.pool.id' : self .identity_pool 
9796        }
9897
99-     def  token_expired (self ) ->  bool :
100-         expiry_window  =  self .token ['expires_in' ] *  self .token_expiry_threshold 
101- 
102-         return  self .token ['expires_at' ] <  time .time () +  expiry_window 
103- 
10498    async  def  get_access_token (self ) ->  str :
10599        if  not  self .token  or  self .token_expired ():
106100            await  self .generate_access_token ()
107101
108-         return  self .token ['access_token' ]
102+         return  self .token 
103+ 
104+     @abc .abstractmethod  
105+     def  token_expired (self ) ->  bool :
106+         raise  NotImplementedError 
107+ 
108+     @abc .abstractmethod  
109+     async  def  fetch_token (self ) ->  str :
110+         raise  NotImplementedError 
109111
110112    async  def  generate_access_token (self ) ->  None :
111113        for  i  in  range (self .max_retries  +  1 ):
112114            try :
113-                 self .token  =  await  self .client . fetch_token (url = self . token_endpoint ,  grant_type = 'client_credentials' )
115+                 self .token  =  await  self .fetch_token ()
114116                return 
115117            except  Exception  as  e :
116118                if  i  >=  self .max_retries :
@@ -119,9 +121,51 @@ async def generate_access_token(self) -> None:
119121                await  asyncio .sleep (full_jitter (self .retries_wait_ms , self .retries_max_wait_ms , i ) /  1000 )
120122
121123
124+ class  _AsyncOAuthClient (_AsyncAbstractOAuthClient ):
125+     def  __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str , logical_cluster : str ,
126+                  identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
127+         super ().__init__ (
128+             logical_cluster , identity_pool , max_retries , retries_wait_ms ,
129+             retries_max_wait_ms )
130+         self .client  =  AsyncOAuth2Client (client_id = client_id , client_secret = client_secret , scope = scope )
131+         self .token_endpoint : str  =  token_endpoint 
132+         self .token_object : dict  =  None 
133+         self .token_expiry_threshold : float  =  0.8 
134+ 
135+     def  token_expired (self ) ->  bool :
136+         expiry_window  =  self .token_object ['expires_in' ] *  self .token_expiry_threshold 
137+         return  self .token_object ['expires_at' ] <  time .time () +  expiry_window 
138+ 
139+     async  def  fetch_token (self ) ->  str :
140+         self .token_object  =  await  self .client .fetch_token (url = self .token_endpoint , grant_type = 'client_credentials' )
141+         return  self .token_object ['access_token' ]
142+ 
143+ 
144+ class  _AsyncOAuthAzureIMDSClient (_AsyncAbstractOAuthClient ):
145+     def  __init__ (self , token_endpoint : str , logical_cluster : str ,
146+                  identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
147+         super ().__init__ (
148+             logical_cluster , identity_pool , max_retries , retries_wait_ms ,
149+             retries_max_wait_ms )
150+         self .client  =  httpx .AsyncClient ()
151+         self .token_endpoint : str  =  token_endpoint 
152+         self .token_object : dict  =  None 
153+         self .token_expiry_threshold : float  =  0.8 
154+ 
155+     def  token_expired (self ) ->  bool :
156+         expiry_window  =  int (self .token_object ['expires_in' ]) *  self .token_expiry_threshold 
157+         return  int (self .token_object ['expires_on' ]) <  time .time () +  expiry_window 
158+ 
159+     async  def  fetch_token (self ) ->  str :
160+         self .token_object  =  await  self .client .get (self .token_endpoint , headers = [
161+             ('Metadata' , 'true' )
162+         ]).json ()
163+         return  self .token_object ['access_token' ]
164+ 
165+ 
122166class  _AsyncOAuthBearerOIDCFieldProviderBuilder (_AbstractOAuthBearerOIDCFieldProviderBuilder ):
123167
124-     def  build (self , max_retries , retries_wait_ms , retries_max_wait_ms ):
168+     def  build (self , max_retries :  int , retries_wait_ms :  int , retries_max_wait_ms :  int ):
125169        self ._validate ()
126170        return  _AsyncOAuthClient (
127171            self .client_id , self .client_secret , self .scope ,
@@ -132,9 +176,21 @@ def build(self, max_retries, retries_wait_ms, retries_max_wait_ms):
132176            retries_max_wait_ms )
133177
134178
179+ class  _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder (_AbstractOAuthBearerOIDCAzureIMDSFieldProviderBuilder ):
180+ 
181+     def  build (self , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
182+         self ._validate ()
183+         return  _AsyncOAuthAzureIMDSClient (
184+             self .token_endpoint ,
185+             self .logical_cluster ,
186+             self .identity_pool ,
187+             max_retries , retries_wait_ms ,
188+             retries_max_wait_ms )
189+ 
190+ 
135191class  _AsyncCustomOAuthBearerFieldProviderBuilder (_AbstractCustomOAuthBearerFieldProviderBuilder ):
136192
137-     def  build (self , max_retries , retries_wait_ms , retries_max_wait_ms ):
193+     def  build (self , max_retries :  int , retries_wait_ms :  int , retries_max_wait_ms :  int ):
138194        self ._validate ()
139195        return  _AsyncCustomOAuthClient (
140196            self .custom_function ,
@@ -146,12 +202,13 @@ class _AsyncFieldProviderBuilder:
146202
147203    __builders  =  {
148204        "OAUTHBEARER" : _AsyncOAuthBearerOIDCFieldProviderBuilder ,
205+         "OAUTHBEARER_AZURE_IMDS" : _AsyncOAuthBearerOIDCAzureIMDSFieldProviderBuilder ,
149206        "STATIC_TOKEN" : _StaticOAuthBearerFieldProviderBuilder ,
150207        "CUSTOM" : _AsyncCustomOAuthBearerFieldProviderBuilder 
151208    }
152209
153210    @staticmethod  
154-     def  build (conf , max_retries , retries_wait_ms , retries_max_wait_ms ):
211+     def  build (conf , max_retries :  int , retries_wait_ms :  int , retries_max_wait_ms :  int ):
155212        bearer_auth_credentials_source  =  conf .pop ('bearer.auth.credentials.source' , None )
156213        if  bearer_auth_credentials_source  is  None :
157214            return  [None , None ]
0 commit comments