99from typing import List , Optional
1010
1111import oauthlib .oauth2
12- import requests
1312from oauthlib .oauth2 .rfc6749 .errors import OAuth2Error
14- from requests .exceptions import RequestException
15- from databricks .sql .common .http import HttpMethod , DatabricksHttpClient , HttpHeader
13+ from databricks .sql .common .http import HttpMethod , HttpHeader
1614from databricks .sql .common .http import OAuthResponse
1715from databricks .sql .auth .oauth_http_handler import OAuthHttpSingleRequestHandler
1816from databricks .sql .auth .endpoint import OAuthEndpointCollection
@@ -63,33 +61,19 @@ def refresh(self) -> Token:
6361 pass
6462
6563
66- class IgnoreNetrcAuth (requests .auth .AuthBase ):
67- """This auth method is a no-op.
68-
69- We use it to force requestslib to not use .netrc to write auth headers
70- when making .post() requests to the oauth token endpoints, since these
71- don't require authentication.
72-
73- In cases where .netrc is outdated or corrupt, these requests will fail.
74-
75- See issue #121
76- """
77-
78- def __call__ (self , r ):
79- return r
80-
81-
8264class OAuthManager :
8365 def __init__ (
8466 self ,
8567 port_range : List [int ],
8668 client_id : str ,
8769 idp_endpoint : OAuthEndpointCollection ,
70+ http_client ,
8871 ):
8972 self .port_range = port_range
9073 self .client_id = client_id
9174 self .redirect_port = None
9275 self .idp_endpoint = idp_endpoint
76+ self .http_client = http_client
9377
9478 @staticmethod
9579 def __token_urlsafe (nbytes = 32 ):
@@ -103,8 +87,11 @@ def __fetch_well_known_config(self, hostname: str):
10387 known_config_url = self .idp_endpoint .get_openid_config_url (hostname )
10488
10589 try :
106- response = requests .get (url = known_config_url , auth = IgnoreNetrcAuth ())
107- except RequestException as e :
90+ response = self .http_client .request (HttpMethod .GET , url = known_config_url )
91+ # Convert urllib3 response to requests-like response for compatibility
92+ response .status_code = response .status
93+ response .json = lambda : json .loads (response .data .decode ())
94+ except Exception as e :
10895 logger .error (
10996 f"Unable to fetch OAuth configuration from { known_config_url } .\n "
11097 "Verify it is a valid workspace URL and that OAuth is "
@@ -122,7 +109,7 @@ def __fetch_well_known_config(self, hostname: str):
122109 raise RuntimeError (msg )
123110 try :
124111 return response .json ()
125- except requests . exceptions . JSONDecodeError as e :
112+ except Exception as e :
126113 logger .error (
127114 f"Unable to decode OAuth configuration from { known_config_url } .\n "
128115 "Verify it is a valid workspace URL and that OAuth is "
@@ -203,16 +190,17 @@ def __send_auth_code_token_request(
203190 data = f"{ token_request_body } &code_verifier={ verifier } "
204191 return self .__send_token_request (token_request_url , data )
205192
206- @staticmethod
207- def __send_token_request (token_request_url , data ):
193+ def __send_token_request (self , token_request_url , data ):
208194 headers = {
209195 "Accept" : "application/json" ,
210196 "Content-Type" : "application/x-www-form-urlencoded" ,
211197 }
212- response = requests .post (
213- url = token_request_url , data = data , headers = headers , auth = IgnoreNetrcAuth ()
198+ # Use unified HTTP client
199+ response = self .http_client .request (
200+ HttpMethod .POST , url = token_request_url , body = data , headers = headers
214201 )
215- return response .json ()
202+ # Convert urllib3 response to dict for compatibility
203+ return json .loads (response .data .decode ())
216204
217205 def __send_refresh_token_request (self , hostname , refresh_token ):
218206 oauth_config = self .__fetch_well_known_config (hostname )
@@ -221,7 +209,7 @@ def __send_refresh_token_request(self, hostname, refresh_token):
221209 token_request_body = client .prepare_refresh_body (
222210 refresh_token = refresh_token , client_id = client .client_id
223211 )
224- return OAuthManager .__send_token_request (token_request_url , token_request_body )
212+ return self .__send_token_request (token_request_url , token_request_body )
225213
226214 @staticmethod
227215 def __get_tokens_from_response (oauth_response ):
@@ -320,14 +308,15 @@ def __init__(
320308 token_url ,
321309 client_id ,
322310 client_secret ,
311+ http_client ,
323312 extra_params : dict = {},
324313 ):
325314 self .client_id = client_id
326315 self .client_secret = client_secret
327316 self .token_url = token_url
328317 self .extra_params = extra_params
329318 self .token : Optional [Token ] = None
330- self ._http_client = DatabricksHttpClient . get_instance ()
319+ self ._http_client = http_client
331320
332321 def get_token (self ) -> Token :
333322 if self .token is None or self .token .is_expired ():
@@ -348,17 +337,17 @@ def refresh(self) -> Token:
348337 }
349338 )
350339
351- with self ._http_client .execute (
352- method = HttpMethod .POST , url = self .token_url , headers = headers , data = data
353- ) as response :
354- if response .status_code == 200 :
355- oauth_response = OAuthResponse (** response .json ( ))
356- return Token (
357- oauth_response .access_token ,
358- oauth_response .token_type ,
359- oauth_response .refresh_token ,
360- )
361- else :
362- raise Exception (
363- f"Failed to get token: { response .status_code } { response .text } "
364- )
340+ response = self ._http_client .request (
341+ method = HttpMethod .POST , url = self .token_url , headers = headers , body = data
342+ )
343+ if response .status == 200 :
344+ oauth_response = OAuthResponse (** json . loads ( response .data . decode ( "utf-8" ) ))
345+ return Token (
346+ oauth_response .access_token ,
347+ oauth_response .token_type ,
348+ oauth_response .refresh_token ,
349+ )
350+ else :
351+ raise Exception (
352+ f"Failed to get token: { response .status } { response .data . decode ( 'utf-8' ) } "
353+ )
0 commit comments