Skip to content

Commit d747ea0

Browse files
committed
Update databricks tests
1 parent a33c153 commit d747ea0

File tree

2 files changed

+106
-124
lines changed

2 files changed

+106
-124
lines changed

src/posit/connect/external/databricks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232

3333
POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration"
34-
POSIT_LOCAL_CLIENT_CREDENTIALS_AUTH_TYPE = "posit-local-client-credentials"
3534
POSIT_WORKBENCH_AUTH_TYPE = "posit-workbench"
3635

3736
logger = logging.getLogger("posit.sdk")
Lines changed: 106 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import base64
1+
from __future__ import annotations
2+
23
from unittest.mock import patch
34

45
import pytest
@@ -8,27 +9,43 @@
89
from posit.connect import Client
910
from posit.connect.external.databricks import (
1011
POSIT_OAUTH_INTEGRATION_AUTH_TYPE,
11-
CredentialsProvider,
12-
CredentialsStrategy,
13-
PositContentCredentialsProvider,
14-
PositContentCredentialsStrategy,
15-
PositCredentialsProvider,
16-
PositCredentialsStrategy,
17-
PositLocalContentCredentialsProvider,
18-
PositLocalContentCredentialsStrategy,
19-
_get_auth_type,
12+
POSIT_WORKBENCH_AUTH_TYPE,
13+
ConnectStrategy,
14+
WorkbenchStrategy,
2015
_new_bearer_authorization_header,
16+
_PositConnectContentCredentialsProvider,
17+
_PositConnectViewerCredentialsProvider,
18+
databricks_config,
2119
)
2220
from posit.connect.oauth import Credentials
2321

22+
try:
23+
from databricks.sdk.core import Config, DefaultCredentials
24+
from databricks.sdk.credentials_provider import (
25+
CredentialsProvider,
26+
CredentialsStrategy,
27+
)
28+
29+
# construct a DefaultCredentials CredentialsStrategy
30+
# weirdly, you have to call `__call__()` at least once in order to initialize `auth_type()`
31+
# This is the expected credentials strategy when none is provided to our databricks_config() helper
32+
expected_credentials = DefaultCredentials() # pyright: ignore[reportPossiblyUnboundVariable]
33+
expected_credentials(Config(auth_type="databricks-cli")) # pyright: ignore[reportPossiblyUnboundVariable]
34+
35+
except ImportError:
36+
pytestmark = pytest.mark.skipif(True, reason="requires the Databricks SDK")
37+
38+
39+
class mock_strategy(CredentialsStrategy): # pyright: ignore[reportPossiblyUnboundVariable]
40+
def __init__(self, name: str):
41+
self.name = name
2442

25-
class mock_strategy(CredentialsStrategy):
2643
def auth_type(self) -> str:
27-
return "local"
44+
return self.name
2845

29-
def __call__(self) -> CredentialsProvider:
46+
def __call__(self, *args, **kwargs) -> CredentialsProvider:
3047
def inner() -> Dict[str, str]:
31-
return {"Authorization": "Bearer static-pat-token"}
48+
return {"Authorization": f"Bearer {self.name}"}
3249

3350
return inner
3451

@@ -84,50 +101,14 @@ def test_new_bearer_authorization_header(self):
84101
result = _new_bearer_authorization_header(credential)
85102
assert result == {"Authorization": "Bearer access_token"}
86103

87-
def test_get_auth_type_local(self):
88-
assert _get_auth_type("local-auth") == "local-auth"
89-
90-
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
91-
def test_get_auth_type_connect(self):
92-
assert _get_auth_type("local-auth") == POSIT_OAUTH_INTEGRATION_AUTH_TYPE
93-
94-
@responses.activate
95-
def test_local_content_credentials_provider(self):
96-
token_url = "https://my-token/url"
97-
client_id = "client_id"
98-
client_secret = "client_secret_123"
99-
basic_auth = f"{client_id}:{client_secret}"
100-
b64_basic_auth = base64.b64encode(basic_auth.encode("utf-8")).decode("utf-8")
101-
102-
responses.post(
103-
token_url,
104-
match=[
105-
responses.matchers.urlencoded_params_matcher(
106-
{
107-
"grant_type": "client_credentials",
108-
"scope": "all-apis",
109-
},
110-
),
111-
responses.matchers.header_matcher({"Authorization": f"Basic {b64_basic_auth}"}),
112-
],
113-
json={
114-
"access_token": "oauth2-m2m-access-token",
115-
"token_type": "Bearer",
116-
"expires_in": 3600,
117-
},
118-
)
119-
120-
cp = PositLocalContentCredentialsProvider(token_url, client_id, client_secret)
121-
assert cp() == {"Authorization": "Bearer oauth2-m2m-access-token"}
122-
123104
@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
124105
@responses.activate
125106
def test_posit_content_credentials_provider(self):
126107
register_mocks()
127108

128109
client = Client(api_key="12345", url="https://connect.example/")
129110
client._ctx.version = None
130-
cp = PositContentCredentialsProvider(client=client)
111+
cp = _PositConnectContentCredentialsProvider(client=client)
131112
assert cp() == {"Authorization": "Bearer content-access-token"}
132113

133114
@responses.activate
@@ -136,95 +117,97 @@ def test_posit_credentials_provider(self):
136117

137118
client = Client(api_key="12345", url="https://connect.example/")
138119
client._ctx.version = None
139-
cp = PositCredentialsProvider(client=client, user_session_token="cit")
120+
cp = _PositConnectViewerCredentialsProvider(client=client, user_session_token="cit")
140121
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}
141122

142-
@responses.activate
143-
def test_local_content_credentials_strategy(self):
144-
token_url = "https://my-token/url"
145-
client_id = "client_id"
146-
client_secret = "client_secret_123"
147-
basic_auth = f"{client_id}:{client_secret}"
148-
b64_basic_auth = base64.b64encode(basic_auth.encode("utf-8")).decode("utf-8")
149-
150-
responses.post(
151-
token_url,
152-
match=[
153-
responses.matchers.urlencoded_params_matcher(
154-
{
155-
"grant_type": "client_credentials",
156-
"scope": "all-apis",
157-
},
158-
),
159-
responses.matchers.header_matcher({"Authorization": f"Basic {b64_basic_auth}"}),
160-
],
161-
json={
162-
"access_token": "oauth2-m2m-access-token",
163-
"token_type": "Bearer",
164-
"expires_in": 3600,
165-
},
166-
)
123+
def test_workbench_strategy(self):
124+
# default will attempt to load the workbench profile
125+
with pytest.raises(ValueError, match="profile=workbench"):
126+
WorkbenchStrategy()
167127

168-
cs = PositLocalContentCredentialsStrategy(
169-
token_url,
170-
client_id,
171-
client_secret,
128+
# providing a Config is allowed
129+
cs = WorkbenchStrategy(
130+
config=Config(host="https://databricks.com/workspace", token="token") # pyright: ignore[reportPossiblyUnboundVariable]
172131
)
132+
assert cs.auth_type() == POSIT_WORKBENCH_AUTH_TYPE
173133
cp = cs()
174-
assert cs.auth_type() == "posit-local-client-credentials"
175-
assert cp() == {"Authorization": "Bearer oauth2-m2m-access-token"}
176134

135+
# token from the Config is passed through to the auth header
136+
assert cp() == {"Authorization": "Bearer token"}
137+
138+
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
177139
@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
178140
@responses.activate
179-
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
180-
def test_posit_content_credentials_strategy(self):
141+
def test_connect_strategy(self):
181142
register_mocks()
182-
183143
client = Client(api_key="12345", url="https://connect.example/")
184144
client._ctx.version = None
185-
cs = PositContentCredentialsStrategy(
186-
local_strategy=mock_strategy(),
187-
client=client,
188-
)
145+
146+
# the default implementation uses Service Account authentication
147+
cs = ConnectStrategy(client=client)
148+
assert cs.auth_type() == POSIT_OAUTH_INTEGRATION_AUTH_TYPE
189149
cp = cs()
190-
assert cs.auth_type() == "posit-oauth-integration"
191150
assert cp() == {"Authorization": "Bearer content-access-token"}
192151

193-
@responses.activate
194-
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
195-
def test_posit_credentials_strategy(self):
196-
register_mocks()
197-
198-
client = Client(api_key="12345", url="https://connect.example/")
199-
client._ctx.version = None
200-
cs = PositCredentialsStrategy(
201-
local_strategy=mock_strategy(),
202-
user_session_token="cit",
203-
client=client,
204-
)
152+
# if a session token is provided then Viewer auth is used
153+
cs = ConnectStrategy(client=client, user_session_token="cit")
205154
cp = cs()
206-
assert cs.auth_type() == "posit-oauth-integration"
207155
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}
208156

209-
def test_posit_content_credentials_strategy_fallback(self):
210-
# local_strategy is used when the content is running locally
211-
client = Client(api_key="12345", url="https://connect.example/")
212-
cs = PositContentCredentialsStrategy(
213-
local_strategy=mock_strategy(),
214-
client=client,
157+
def test_databricks_config(self):
158+
# credentials_strategy is removed if it is provided
159+
cfg = databricks_config(credentials_strategy=mock_strategy("mock"))
160+
assert cfg._credentials_strategy is not None
161+
assert cfg._credentials_strategy.auth_type() != "mock"
162+
163+
# kwargs are passed through to the Config() constructor
164+
cfg = databricks_config(
165+
host="https://databricks.com",
166+
cluster_id="cluster_id",
167+
warehouse_id="warehouse_id",
168+
token="token",
215169
)
216-
cp = cs()
217-
assert cs.auth_type() == "local"
218-
assert cp() == {"Authorization": "Bearer static-pat-token"}
170+
assert cfg.host == "https://databricks.com"
171+
assert cfg.cluster_id == "cluster_id"
172+
assert cfg.warehouse_id == "warehouse_id"
173+
assert cfg.token == "token"
174+
175+
def test_databricks_config_default(self):
176+
cfg = databricks_config(
177+
posit_default_strategy=mock_strategy("default"),
178+
posit_workbench_strategy=mock_strategy("workbench"),
179+
posit_connect_strategy=mock_strategy("connect"),
180+
)
181+
assert cfg._credentials_strategy.auth_type() == "default"
182+
183+
# default fallback defaults to DefaultCredentials() when none is provided
184+
cfg = databricks_config(auth_type="databricks-cli")
185+
assert cfg._credentials_strategy.auth_type() == expected_credentials.auth_type()
186+
187+
@patch.dict("os.environ", {"RS_SERVER_ADDRESS": "https://workbench.posit.co/"})
188+
def test_databricks_config_workbench(self):
189+
cfg = databricks_config(
190+
posit_default_strategy=mock_strategy("default"),
191+
posit_workbench_strategy=mock_strategy("workbench"),
192+
posit_connect_strategy=mock_strategy("connect"),
193+
)
194+
assert cfg._credentials_strategy.auth_type() == "workbench"
219195

220-
def test_posit_credentials_strategy_fallback(self):
221-
# local_strategy is used when the content is running locally
222-
client = Client(api_key="12345", url="https://connect.example/")
223-
cs = PositCredentialsStrategy(
224-
local_strategy=mock_strategy(),
225-
user_session_token="cit",
226-
client=client,
196+
# workbench defaults to DefaultCredentials() when none is provided
197+
cfg = databricks_config(auth_type="databricks-cli")
198+
assert cfg._credentials_strategy.auth_type() == expected_credentials.auth_type()
199+
200+
@patch.dict("os.environ", {"CONNECT_API_KEY": "API_KEY"})
201+
@patch.dict("os.environ", {"CONNECT_SERVER": "https://connect.posit.co/"})
202+
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
203+
def test_databricks_config_connect(self):
204+
cfg = databricks_config(
205+
posit_default_strategy=mock_strategy("default"),
206+
posit_workbench_strategy=mock_strategy("workbench"),
207+
posit_connect_strategy=mock_strategy("connect"),
227208
)
228-
cp = cs()
229-
assert cs.auth_type() == "local"
230-
assert cp() == {"Authorization": "Bearer static-pat-token"}
209+
assert cfg._credentials_strategy.auth_type() == "connect"
210+
211+
# connect defaults to ConnectStrategy() when none is provided
212+
cfg = databricks_config()
213+
assert cfg._credentials_strategy.auth_type() == ConnectStrategy().auth_type()

0 commit comments

Comments
 (0)