1111import requests
1212
1313from .clock import Clock , RealClock
14- from .credentials_provider import CredentialsProvider , DefaultCredentials
14+ from .credentials_provider import CredentialsStrategy , DefaultCredentials
1515from .environments import (ALL_ENVS , AzureEnvironment , Cloud ,
1616 DatabricksEnvironment , get_environment_for_hostname )
17- from .oauth import OidcEndpoints
17+ from .oauth import OidcEndpoints , Token
1818from .version import __version__
1919
2020logger = logging .getLogger ('databricks.sdk' )
@@ -81,15 +81,25 @@ class Config:
8181
8282 def __init__ (self ,
8383 * ,
84- credentials_provider : CredentialsProvider = None ,
84+ # Deprecated. Use credentials_strategy instead.
85+ credentials_provider : CredentialsStrategy = None ,
86+ credentials_strategy : CredentialsStrategy = None ,
8587 product = "unknown" ,
8688 product_version = "0.0.0" ,
8789 clock : Clock = None ,
8890 ** kwargs ):
8991 self ._header_factory = None
9092 self ._inner = {}
9193 self ._user_agent_other_info = []
92- self ._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials ()
94+ if credentials_strategy and credentials_provider :
95+ raise ValueError (
96+ "When providing `credentials_strategy` field, `credential_provider` cannot be specified." )
97+ if credentials_provider :
98+ logger .warning (
99+ "parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead." )
100+ self ._credentials_strategy = next (
101+ s for s in [credentials_strategy , credentials_provider ,
102+ DefaultCredentials ()] if s is not None )
93103 if 'databricks_environment' in kwargs :
94104 self .databricks_environment = kwargs ['databricks_environment' ]
95105 del kwargs ['databricks_environment' ]
@@ -107,6 +117,9 @@ def __init__(self,
107117 message = self .wrap_debug_info (str (e ))
108118 raise ValueError (message ) from e
109119
120+ def oauth_token (self ) -> Token :
121+ return self ._credentials_strategy .oauth_token (self )
122+
110123 def wrap_debug_info (self , message : str ) -> str :
111124 debug_string = self .debug_string ()
112125 if debug_string :
@@ -436,12 +449,12 @@ def _validate(self):
436449
437450 def init_auth (self ):
438451 try :
439- self ._header_factory = self ._credentials_provider (self )
440- self .auth_type = self ._credentials_provider .auth_type ()
452+ self ._header_factory = self ._credentials_strategy (self )
453+ self .auth_type = self ._credentials_strategy .auth_type ()
441454 if not self ._header_factory :
442455 raise ValueError ('not configured' )
443456 except ValueError as e :
444- raise ValueError (f'{ self ._credentials_provider .auth_type ()} auth: { e } ' ) from e
457+ raise ValueError (f'{ self ._credentials_strategy .auth_type ()} auth: { e } ' ) from e
445458
446459 def __repr__ (self ):
447460 return f'<{ self .debug_string ()} >'
0 commit comments