1- import base64
1+ from __future__ import annotations
2+
23from unittest .mock import patch
34
45import pytest
89from posit .connect import Client
910from posit .connect .external .databricks import (
1011 POSIT_OAUTH_INTEGRATION_AUTH_TYPE ,
11- CredentialsProvider ,
12- CredentialsStrategy ,
13- PositContentCredentialsProvider ,
14- PositContentCredentialsStrategy ,
15- PositCredentialsProvider ,
16- PositCredentialsStrategy ,
17- PositLocalContentCredentialsProvider ,
18- PositLocalContentCredentialsStrategy ,
19- _get_auth_type ,
12+ POSIT_WORKBENCH_AUTH_TYPE ,
13+ ConnectStrategy ,
14+ WorkbenchStrategy ,
2015 _new_bearer_authorization_header ,
16+ _PositConnectContentCredentialsProvider ,
17+ _PositConnectViewerCredentialsProvider ,
18+ databricks_config ,
2119)
2220from posit .connect .oauth import Credentials
2321
22+ try :
23+ from databricks .sdk .core import Config , DefaultCredentials
24+ from databricks .sdk .credentials_provider import (
25+ CredentialsProvider ,
26+ CredentialsStrategy ,
27+ )
28+
29+ # construct a DefaultCredentials CredentialsStrategy
30+ # weirdly, you have to call `__call__()` at least once in order to initialize `auth_type()`
31+ # This is the expected credentials strategy when none is provided to our databricks_config() helper
32+ expected_credentials = DefaultCredentials () # pyright: ignore[reportPossiblyUnboundVariable]
33+ expected_credentials (Config (auth_type = "databricks-cli" )) # pyright: ignore[reportPossiblyUnboundVariable]
34+
35+ except ImportError :
36+ pytestmark = pytest .mark .skipif (True , reason = "requires the Databricks SDK" )
37+
38+
39+ class mock_strategy (CredentialsStrategy ): # pyright: ignore[reportPossiblyUnboundVariable]
40+ def __init__ (self , name : str ):
41+ self .name = name
2442
25- class mock_strategy (CredentialsStrategy ):
2643 def auth_type (self ) -> str :
27- return "local"
44+ return self . name
2845
29- def __call__ (self ) -> CredentialsProvider :
46+ def __call__ (self , * args , ** kwargs ) -> CredentialsProvider :
3047 def inner () -> Dict [str , str ]:
31- return {"Authorization" : "Bearer static-pat-token " }
48+ return {"Authorization" : f "Bearer { self . name } " }
3249
3350 return inner
3451
@@ -84,50 +101,14 @@ def test_new_bearer_authorization_header(self):
84101 result = _new_bearer_authorization_header (credential )
85102 assert result == {"Authorization" : "Bearer access_token" }
86103
87- def test_get_auth_type_local (self ):
88- assert _get_auth_type ("local-auth" ) == "local-auth"
89-
90- @patch .dict ("os.environ" , {"RSTUDIO_PRODUCT" : "CONNECT" })
91- def test_get_auth_type_connect (self ):
92- assert _get_auth_type ("local-auth" ) == POSIT_OAUTH_INTEGRATION_AUTH_TYPE
93-
94- @responses .activate
95- def test_local_content_credentials_provider (self ):
96- token_url = "https://my-token/url"
97- client_id = "client_id"
98- client_secret = "client_secret_123"
99- basic_auth = f"{ client_id } :{ client_secret } "
100- b64_basic_auth = base64 .b64encode (basic_auth .encode ("utf-8" )).decode ("utf-8" )
101-
102- responses .post (
103- token_url ,
104- match = [
105- responses .matchers .urlencoded_params_matcher (
106- {
107- "grant_type" : "client_credentials" ,
108- "scope" : "all-apis" ,
109- },
110- ),
111- responses .matchers .header_matcher ({"Authorization" : f"Basic { b64_basic_auth } " }),
112- ],
113- json = {
114- "access_token" : "oauth2-m2m-access-token" ,
115- "token_type" : "Bearer" ,
116- "expires_in" : 3600 ,
117- },
118- )
119-
120- cp = PositLocalContentCredentialsProvider (token_url , client_id , client_secret )
121- assert cp () == {"Authorization" : "Bearer oauth2-m2m-access-token" }
122-
123104 @patch .dict ("os.environ" , {"CONNECT_CONTENT_SESSION_TOKEN" : "cit" })
124105 @responses .activate
125106 def test_posit_content_credentials_provider (self ):
126107 register_mocks ()
127108
128109 client = Client (api_key = "12345" , url = "https://connect.example/" )
129110 client ._ctx .version = None
130- cp = PositContentCredentialsProvider (client = client )
111+ cp = _PositConnectContentCredentialsProvider (client = client )
131112 assert cp () == {"Authorization" : "Bearer content-access-token" }
132113
133114 @responses .activate
@@ -136,95 +117,97 @@ def test_posit_credentials_provider(self):
136117
137118 client = Client (api_key = "12345" , url = "https://connect.example/" )
138119 client ._ctx .version = None
139- cp = PositCredentialsProvider (client = client , user_session_token = "cit" )
120+ cp = _PositConnectViewerCredentialsProvider (client = client , user_session_token = "cit" )
140121 assert cp () == {"Authorization" : "Bearer dynamic-viewer-access-token" }
141122
142- @responses .activate
143- def test_local_content_credentials_strategy (self ):
144- token_url = "https://my-token/url"
145- client_id = "client_id"
146- client_secret = "client_secret_123"
147- basic_auth = f"{ client_id } :{ client_secret } "
148- b64_basic_auth = base64 .b64encode (basic_auth .encode ("utf-8" )).decode ("utf-8" )
149-
150- responses .post (
151- token_url ,
152- match = [
153- responses .matchers .urlencoded_params_matcher (
154- {
155- "grant_type" : "client_credentials" ,
156- "scope" : "all-apis" ,
157- },
158- ),
159- responses .matchers .header_matcher ({"Authorization" : f"Basic { b64_basic_auth } " }),
160- ],
161- json = {
162- "access_token" : "oauth2-m2m-access-token" ,
163- "token_type" : "Bearer" ,
164- "expires_in" : 3600 ,
165- },
166- )
123+ def test_workbench_strategy (self ):
124+ # default will attempt to load the workbench profile
125+ with pytest .raises (ValueError , match = "profile=workbench" ):
126+ WorkbenchStrategy ()
167127
168- cs = PositLocalContentCredentialsStrategy (
169- token_url ,
170- client_id ,
171- client_secret ,
128+ # providing a Config is allowed
129+ cs = WorkbenchStrategy (
130+ config = Config (host = "https://databricks.com/workspace" , token = "token" ) # pyright: ignore[reportPossiblyUnboundVariable]
172131 )
132+ assert cs .auth_type () == POSIT_WORKBENCH_AUTH_TYPE
173133 cp = cs ()
174- assert cs .auth_type () == "posit-local-client-credentials"
175- assert cp () == {"Authorization" : "Bearer oauth2-m2m-access-token" }
176134
135+ # token from the Config is passed through to the auth header
136+ assert cp () == {"Authorization" : "Bearer token" }
137+
138+ @patch .dict ("os.environ" , {"RSTUDIO_PRODUCT" : "CONNECT" })
177139 @patch .dict ("os.environ" , {"CONNECT_CONTENT_SESSION_TOKEN" : "cit" })
178140 @responses .activate
179- @patch .dict ("os.environ" , {"RSTUDIO_PRODUCT" : "CONNECT" })
180- def test_posit_content_credentials_strategy (self ):
141+ def test_connect_strategy (self ):
181142 register_mocks ()
182-
183143 client = Client (api_key = "12345" , url = "https://connect.example/" )
184144 client ._ctx .version = None
185- cs = PositContentCredentialsStrategy (
186- local_strategy = mock_strategy (),
187- client = client ,
188- )
145+
146+ # the default implementation uses Service Account authentication
147+ cs = ConnectStrategy ( client = client )
148+ assert cs . auth_type () == POSIT_OAUTH_INTEGRATION_AUTH_TYPE
189149 cp = cs ()
190- assert cs .auth_type () == "posit-oauth-integration"
191150 assert cp () == {"Authorization" : "Bearer content-access-token" }
192151
193- @responses .activate
194- @patch .dict ("os.environ" , {"RSTUDIO_PRODUCT" : "CONNECT" })
195- def test_posit_credentials_strategy (self ):
196- register_mocks ()
197-
198- client = Client (api_key = "12345" , url = "https://connect.example/" )
199- client ._ctx .version = None
200- cs = PositCredentialsStrategy (
201- local_strategy = mock_strategy (),
202- user_session_token = "cit" ,
203- client = client ,
204- )
152+ # if a session token is provided then Viewer auth is used
153+ cs = ConnectStrategy (client = client , user_session_token = "cit" )
205154 cp = cs ()
206- assert cs .auth_type () == "posit-oauth-integration"
207155 assert cp () == {"Authorization" : "Bearer dynamic-viewer-access-token" }
208156
209- def test_posit_content_credentials_strategy_fallback (self ):
210- # local_strategy is used when the content is running locally
211- client = Client (api_key = "12345" , url = "https://connect.example/" )
212- cs = PositContentCredentialsStrategy (
213- local_strategy = mock_strategy (),
214- client = client ,
157+ def test_databricks_config (self ):
158+ # credentials_strategy is removed if it is provided
159+ cfg = databricks_config (credentials_strategy = mock_strategy ("mock" ))
160+ assert cfg ._credentials_strategy is not None
161+ assert cfg ._credentials_strategy .auth_type () != "mock"
162+
163+ # kwargs are passed through to the Config() constructor
164+ cfg = databricks_config (
165+ host = "https://databricks.com" ,
166+ cluster_id = "cluster_id" ,
167+ warehouse_id = "warehouse_id" ,
168+ token = "token" ,
215169 )
216- cp = cs ()
217- assert cs .auth_type () == "local"
218- assert cp () == {"Authorization" : "Bearer static-pat-token" }
170+ assert cfg .host == "https://databricks.com"
171+ assert cfg .cluster_id == "cluster_id"
172+ assert cfg .warehouse_id == "warehouse_id"
173+ assert cfg .token == "token"
174+
175+ def test_databricks_config_default (self ):
176+ cfg = databricks_config (
177+ posit_default_strategy = mock_strategy ("default" ),
178+ posit_workbench_strategy = mock_strategy ("workbench" ),
179+ posit_connect_strategy = mock_strategy ("connect" ),
180+ )
181+ assert cfg ._credentials_strategy .auth_type () == "default"
182+
183+ # default fallback defaults to DefaultCredentials() when none is provided
184+ cfg = databricks_config (auth_type = "databricks-cli" )
185+ assert cfg ._credentials_strategy .auth_type () == expected_credentials .auth_type ()
186+
187+ @patch .dict ("os.environ" , {"RS_SERVER_ADDRESS" : "https://workbench.posit.co/" })
188+ def test_databricks_config_workbench (self ):
189+ cfg = databricks_config (
190+ posit_default_strategy = mock_strategy ("default" ),
191+ posit_workbench_strategy = mock_strategy ("workbench" ),
192+ posit_connect_strategy = mock_strategy ("connect" ),
193+ )
194+ assert cfg ._credentials_strategy .auth_type () == "workbench"
219195
220- def test_posit_credentials_strategy_fallback (self ):
221- # local_strategy is used when the content is running locally
222- client = Client (api_key = "12345" , url = "https://connect.example/" )
223- cs = PositCredentialsStrategy (
224- local_strategy = mock_strategy (),
225- user_session_token = "cit" ,
226- client = client ,
196+ # workbench defaults to DefaultCredentials() when none is provided
197+ cfg = databricks_config (auth_type = "databricks-cli" )
198+ assert cfg ._credentials_strategy .auth_type () == expected_credentials .auth_type ()
199+
200+ @patch .dict ("os.environ" , {"CONNECT_API_KEY" : "API_KEY" })
201+ @patch .dict ("os.environ" , {"CONNECT_SERVER" : "https://connect.posit.co/" })
202+ @patch .dict ("os.environ" , {"RSTUDIO_PRODUCT" : "CONNECT" })
203+ def test_databricks_config_connect (self ):
204+ cfg = databricks_config (
205+ posit_default_strategy = mock_strategy ("default" ),
206+ posit_workbench_strategy = mock_strategy ("workbench" ),
207+ posit_connect_strategy = mock_strategy ("connect" ),
227208 )
228- cp = cs ()
229- assert cs .auth_type () == "local"
230- assert cp () == {"Authorization" : "Bearer static-pat-token" }
209+ assert cfg ._credentials_strategy .auth_type () == "connect"
210+
211+ # connect defaults to ConnectStrategy() when none is provided
212+ cfg = databricks_config ()
213+ assert cfg ._credentials_strategy .auth_type () == ConnectStrategy ().auth_type ()
0 commit comments