22from typing import Callable , Dict , Optional
33
44from ..client import Client
5+ from ..oauth import Credentials
56from .external import is_local
67
78"""
89NOTE: These APIs are provided as a convenience and are subject to breaking changes:
910https://github.com/databricks/databricks-sdk-py#interface-stability
1011"""
1112
13+ POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration"
14+
1215# The Databricks SDK CredentialsProvider == Databricks SQL HeaderFactory
1316CredentialsProvider = Callable [[], Dict [str , str ]]
1417
15-
1618class CredentialsStrategy (abc .ABC ):
1719 """Maintain compatibility with the Databricks SQL/SDK client libraries.
1820
@@ -28,20 +30,74 @@ def auth_type(self) -> str:
2830 def __call__ (self , * args , ** kwargs ) -> CredentialsProvider :
2931 raise NotImplementedError
3032
31- # TODO: Refactor common behavior across different cred providers.
33+
34+ def _new_bearer_authorization_header (credentials : Credentials ) -> Dict [str , str ]:
35+ """Helper to transform an Credentials object into the Bearer auth header consumed by databricks.
36+
37+ Raises
38+ ------
39+ ValueError: If provided Credentials object does not contain an access token
40+
41+ Returns
42+ -------
43+ Dict[str, str]
44+ """
45+ access_token = credentials .get ("access_token" )
46+ if access_token is None :
47+ raise ValueError ("Missing value for field 'access_token' in credentials." )
48+ return {"Authorization" : f"Bearer { access_token } " }
49+
50+ def _get_auth_type (local_auth_type : str ) -> str :
51+ """Returns the auth type currently in use.
52+
53+ The databricks-sdk client uses the configurated auth_type to create
54+ a user-agent string which is used for attribution. We should only
55+ overwrite the auth_type if we are using the PositCredentialsStrategy (non-local),
56+ otherwise, we should return the auth_type of the configured local_strategy instead
57+ to avoid breaking someone elses attribution.
58+
59+ https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269
60+
61+ NOTE: The databricks-sql client does not use auth_type to set the user-agent.
62+ https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
63+
64+ Returns
65+ -------
66+ str
67+ """
68+ if is_local ():
69+ return local_auth_type
70+
71+ return POSIT_OAUTH_INTEGRATION_AUTH_TYPE
72+
73+
3274
3375class PositContentCredentialsProvider :
76+ """CredentialsProvider implementation which initiates a credential exchange using a content-session-token."""
77+
3478 def __init__ (self , client : Client ):
3579 self ._client = client
3680
3781 def __call__ (self ) -> Dict [str , str ]:
3882 credentials = self ._client .oauth .get_content_credentials ()
39- access_token = credentials .get ("access_token" )
40- if access_token is None :
41- raise ValueError ("Missing value for field 'access_token' in credentials." )
42- return {"Authorization" : f"Bearer { access_token } " }
83+ return _new_bearer_authorization_header (credentials )
84+
85+
86+ class PositCredentialsProvider :
87+ """CredentialsProvider implementation which initiates a credential exchange using a user-session-token."""
88+
89+ def __init__ (self , client : Client , user_session_token : str ):
90+ self ._client = client
91+ self ._user_session_token = user_session_token
92+
93+ def __call__ (self ) -> Dict [str , str ]:
94+ credentials = self ._client .oauth .get_credentials (self ._user_session_token )
95+ return _new_bearer_authorization_header (credentials )
96+
97+
98+ class PositContentCredentialsStrategy (CredentialsStrategy ):
99+ """CredentialsStrategy implementation which returns a PositContentCredentialsProvider when called."""
43100
44- class PositContentCredentialsStrategy :
45101 def __init__ (
46102 self ,
47103 local_strategy : CredentialsStrategy ,
@@ -51,15 +107,22 @@ def __init__(
51107 self ._client = client
52108
53109 def sql_credentials_provider (self , * args , ** kwargs ):
110+ """The sql connector attempts to call the credentials provider w/o any args.
111+
112+ The SQL client's `ExternalAuthProvider` is not compatible w/ the SDK's implementation of
113+ `CredentialsProvider`, so create a no-arg lambda that wraps the args defined by the real caller.
114+ This way we can pass in a databricks `Config` object required by most of the SDK's `CredentialsProvider`
115+ implementations from where `sql.connect` is called.
116+
117+ https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
118+ """
54119 return lambda : self .__call__ (* args , ** kwargs )
55120
56121 def auth_type (self ) -> str :
57- if is_local ():
58- return self ._local_strategy .auth_type ()
59- else :
60- return "posit-oauth-integration"
122+ return _get_auth_type (self ._local_strategy .auth_type ())
61123
62124 def __call__ (self , * args , ** kwargs ) -> CredentialsProvider :
125+ # If the content is not running on Connect then fall back to local_strategy
63126 if is_local ():
64127 return self ._local_strategy (* args , ** kwargs )
65128
@@ -69,20 +132,9 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider:
69132 return PositContentCredentialsProvider (self ._client )
70133
71134
72- class PositCredentialsProvider :
73- def __init__ (self , client : Client , user_session_token : str ):
74- self ._client = client
75- self ._user_session_token = user_session_token
76-
77- def __call__ (self ) -> Dict [str , str ]:
78- credentials = self ._client .oauth .get_credentials (self ._user_session_token )
79- access_token = credentials .get ("access_token" )
80- if access_token is None :
81- raise ValueError ("Missing value for field 'access_token' in credentials." )
82- return {"Authorization" : f"Bearer { access_token } " }
83-
84-
85135class PositCredentialsStrategy (CredentialsStrategy ):
136+ """CredentialsStrategy implementation which returns a PositContentCredentialsProvider when called."""
137+
86138 def __init__ (
87139 self ,
88140 local_strategy : CredentialsStrategy ,
@@ -106,23 +158,7 @@ def sql_credentials_provider(self, *args, **kwargs):
106158 return lambda : self .__call__ (* args , ** kwargs )
107159
108160 def auth_type (self ) -> str :
109- """Returns the auth type currently in use.
110-
111- The databricks-sdk client uses the configurated auth_type to create
112- a user-agent string which is used for attribution. We should only
113- overwrite the auth_type if we are using the PositCredentialsStrategy (non-local),
114- otherwise, we should return the auth_type of the configured local_strategy instead
115- to avoid breaking someone elses attribution.
116-
117- https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269
118-
119- NOTE: The databricks-sql client does not use auth_type to set the user-agent.
120- https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
121- """
122- if is_local ():
123- return self ._local_strategy .auth_type ()
124- else :
125- return "posit-oauth-integration"
161+ return _get_auth_type (self ._local_strategy .auth_type ())
126162
127163 def __call__ (self , * args , ** kwargs ) -> CredentialsProvider :
128164 # If the content is not running on Connect then fall back to local_strategy
0 commit comments