Skip to content

Commit 08796e2

Browse files
committed
test coverage, linting
1 parent 67edfbb commit 08796e2

File tree

5 files changed

+230
-45
lines changed

5 files changed

+230
-45
lines changed

src/posit/connect/external/databricks.py

Lines changed: 77 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
from typing import Callable, Dict, Optional
33

44
from ..client import Client
5+
from ..oauth import Credentials
56
from .external import is_local
67

78
"""
89
NOTE: These APIs are provided as a convenience and are subject to breaking changes:
910
https://github.com/databricks/databricks-sdk-py#interface-stability
1011
"""
1112

13+
POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration"
14+
1215
# The Databricks SDK CredentialsProvider == Databricks SQL HeaderFactory
1316
CredentialsProvider = Callable[[], Dict[str, str]]
1417

15-
1618
class CredentialsStrategy(abc.ABC):
1719
"""Maintain compatibility with the Databricks SQL/SDK client libraries.
1820
@@ -28,20 +30,74 @@ def auth_type(self) -> str:
2830
def __call__(self, *args, **kwargs) -> CredentialsProvider:
2931
raise NotImplementedError
3032

31-
# TODO: Refactor common behavior across different cred providers.
33+
34+
def _new_bearer_authorization_header(credentials: Credentials) -> Dict[str, str]:
35+
"""Helper to transform an Credentials object into the Bearer auth header consumed by databricks.
36+
37+
Raises
38+
------
39+
ValueError: If provided Credentials object does not contain an access token
40+
41+
Returns
42+
-------
43+
Dict[str, str]
44+
"""
45+
access_token = credentials.get("access_token")
46+
if access_token is None:
47+
raise ValueError("Missing value for field 'access_token' in credentials.")
48+
return {"Authorization": f"Bearer {access_token}"}
49+
50+
def _get_auth_type(local_auth_type: str) -> str:
51+
"""Returns the auth type currently in use.
52+
53+
The databricks-sdk client uses the configurated auth_type to create
54+
a user-agent string which is used for attribution. We should only
55+
overwrite the auth_type if we are using the PositCredentialsStrategy (non-local),
56+
otherwise, we should return the auth_type of the configured local_strategy instead
57+
to avoid breaking someone elses attribution.
58+
59+
https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269
60+
61+
NOTE: The databricks-sql client does not use auth_type to set the user-agent.
62+
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
63+
64+
Returns
65+
-------
66+
str
67+
"""
68+
if is_local():
69+
return local_auth_type
70+
71+
return POSIT_OAUTH_INTEGRATION_AUTH_TYPE
72+
73+
3274

3375
class PositContentCredentialsProvider:
76+
"""CredentialsProvider implementation which initiates a credential exchange using a content-session-token."""
77+
3478
def __init__(self, client: Client):
3579
self._client = client
3680

3781
def __call__(self) -> Dict[str, str]:
3882
credentials = self._client.oauth.get_content_credentials()
39-
access_token = credentials.get("access_token")
40-
if access_token is None:
41-
raise ValueError("Missing value for field 'access_token' in credentials.")
42-
return {"Authorization": f"Bearer {access_token}"}
83+
return _new_bearer_authorization_header(credentials)
84+
85+
86+
class PositCredentialsProvider:
87+
"""CredentialsProvider implementation which initiates a credential exchange using a user-session-token."""
88+
89+
def __init__(self, client: Client, user_session_token: str):
90+
self._client = client
91+
self._user_session_token = user_session_token
92+
93+
def __call__(self) -> Dict[str, str]:
94+
credentials = self._client.oauth.get_credentials(self._user_session_token)
95+
return _new_bearer_authorization_header(credentials)
96+
97+
98+
class PositContentCredentialsStrategy(CredentialsStrategy):
99+
"""CredentialsStrategy implementation which returns a PositContentCredentialsProvider when called."""
43100

44-
class PositContentCredentialsStrategy:
45101
def __init__(
46102
self,
47103
local_strategy: CredentialsStrategy,
@@ -51,15 +107,22 @@ def __init__(
51107
self._client = client
52108

53109
def sql_credentials_provider(self, *args, **kwargs):
110+
"""The sql connector attempts to call the credentials provider w/o any args.
111+
112+
The SQL client's `ExternalAuthProvider` is not compatible w/ the SDK's implementation of
113+
`CredentialsProvider`, so create a no-arg lambda that wraps the args defined by the real caller.
114+
This way we can pass in a databricks `Config` object required by most of the SDK's `CredentialsProvider`
115+
implementations from where `sql.connect` is called.
116+
117+
https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
118+
"""
54119
return lambda: self.__call__(*args, **kwargs)
55120

56121
def auth_type(self) -> str:
57-
if is_local():
58-
return self._local_strategy.auth_type()
59-
else:
60-
return "posit-oauth-integration"
122+
return _get_auth_type(self._local_strategy.auth_type())
61123

62124
def __call__(self, *args, **kwargs) -> CredentialsProvider:
125+
# If the content is not running on Connect then fall back to local_strategy
63126
if is_local():
64127
return self._local_strategy(*args, **kwargs)
65128

@@ -69,20 +132,9 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider:
69132
return PositContentCredentialsProvider(self._client)
70133

71134

72-
class PositCredentialsProvider:
73-
def __init__(self, client: Client, user_session_token: str):
74-
self._client = client
75-
self._user_session_token = user_session_token
76-
77-
def __call__(self) -> Dict[str, str]:
78-
credentials = self._client.oauth.get_credentials(self._user_session_token)
79-
access_token = credentials.get("access_token")
80-
if access_token is None:
81-
raise ValueError("Missing value for field 'access_token' in credentials.")
82-
return {"Authorization": f"Bearer {access_token}"}
83-
84-
85135
class PositCredentialsStrategy(CredentialsStrategy):
136+
"""CredentialsStrategy implementation which returns a PositContentCredentialsProvider when called."""
137+
86138
def __init__(
87139
self,
88140
local_strategy: CredentialsStrategy,
@@ -106,23 +158,7 @@ def sql_credentials_provider(self, *args, **kwargs):
106158
return lambda: self.__call__(*args, **kwargs)
107159

108160
def auth_type(self) -> str:
109-
"""Returns the auth type currently in use.
110-
111-
The databricks-sdk client uses the configurated auth_type to create
112-
a user-agent string which is used for attribution. We should only
113-
overwrite the auth_type if we are using the PositCredentialsStrategy (non-local),
114-
otherwise, we should return the auth_type of the configured local_strategy instead
115-
to avoid breaking someone elses attribution.
116-
117-
https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269
118-
119-
NOTE: The databricks-sql client does not use auth_type to set the user-agent.
120-
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
121-
"""
122-
if is_local():
123-
return self._local_strategy.auth_type()
124-
else:
125-
return "posit-oauth-integration"
161+
return _get_auth_type(self._local_strategy.auth_type())
126162

127163
def __call__(self, *args, **kwargs) -> CredentialsProvider:
128164
# If the content is not running on Connect then fall back to local_strategy
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .oauth import Credentials as Credentials
12
from .oauth import OAuth as OAuth

src/posit/connect/oauth/oauth.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def sessions(self):
4848
return Sessions(self.params)
4949

5050
def get_credentials(self, user_session_token: Optional[str] = None) -> Credentials:
51-
"""Perform an oauth credential exchange for a viewer's access token."""
52-
51+
"""Perform an oauth credential exchange with a user-session-token."""
5352
# craft a credential exchange request
5453
data = {}
5554
data["grant_type"] = GRANT_TYPE
@@ -61,8 +60,7 @@ def get_credentials(self, user_session_token: Optional[str] = None) -> Credentia
6160
return Credentials(**response.json())
6261

6362
def get_content_credentials(self, content_session_token: Optional[str] = None) -> Credentials:
64-
"""Perform an oauth credential exchange for a service account's access token."""
65-
63+
"""Perform an oauth credential exchange with a content-session-token."""
6664
# craft a credential exchange request
6765
data = {}
6866
data["grant_type"] = GRANT_TYPE

tests/posit/connect/external/test_databricks.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from typing import Dict
22
from unittest.mock import patch
33

4+
import pytest
45
import responses
56

67
from posit.connect import Client
78
from posit.connect.external.databricks import (
9+
POSIT_OAUTH_INTEGRATION_AUTH_TYPE,
810
CredentialsProvider,
911
CredentialsStrategy,
12+
PositContentCredentialsProvider,
13+
PositContentCredentialsStrategy,
1014
PositCredentialsProvider,
1115
PositCredentialsStrategy,
16+
_get_auth_type,
17+
_new_bearer_authorization_header,
1218
)
19+
from posit.connect.oauth import Credentials
1320

1421

1522
class mock_strategy(CredentialsStrategy):
@@ -42,8 +49,59 @@ def register_mocks():
4249
},
4350
)
4451

52+
responses.post(
53+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
54+
match=[
55+
responses.matchers.urlencoded_params_matcher(
56+
{
57+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
58+
"subject_token_type": "urn:posit:connect:content-session-token",
59+
"subject_token": "cit",
60+
},
61+
),
62+
],
63+
json={
64+
"access_token": "content-access-token",
65+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
66+
"token_type": "Bearer",
67+
},
68+
)
69+
70+
71+
4572

4673
class TestPositCredentialsHelpers:
74+
75+
def test_new_bearer_authorization_header(self):
76+
credential = Credentials()
77+
credential["token_type"] = "token_type"
78+
credential["issued_token_type"] = "issued_token_type"
79+
80+
with pytest.raises(ValueError):
81+
_new_bearer_authorization_header(credential)
82+
83+
credential["access_token"] = "access_token"
84+
result = _new_bearer_authorization_header(credential)
85+
assert result == {"Authorization": "Bearer access_token"}
86+
87+
def test_get_auth_type_local(self):
88+
assert _get_auth_type("local-auth") == "local-auth"
89+
90+
91+
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
92+
def test_get_auth_type_connect(self):
93+
assert _get_auth_type("local-auth") == POSIT_OAUTH_INTEGRATION_AUTH_TYPE
94+
95+
@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
96+
@responses.activate
97+
def test_posit_content_credentials_provider(self):
98+
register_mocks()
99+
100+
client = Client(api_key="12345", url="https://connect.example/")
101+
client._ctx.version = None
102+
cp = PositContentCredentialsProvider(client=client)
103+
assert cp() == {"Authorization": "Bearer content-access-token"}
104+
47105
@responses.activate
48106
def test_posit_credentials_provider(self):
49107
register_mocks()
@@ -53,6 +111,23 @@ def test_posit_credentials_provider(self):
53111
cp = PositCredentialsProvider(client=client, user_session_token="cit")
54112
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}
55113

114+
@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
115+
@responses.activate
116+
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
117+
def test_posit_content_credentials_strategy(self):
118+
register_mocks()
119+
120+
client = Client(api_key="12345", url="https://connect.example/")
121+
client._ctx.version = None
122+
cs = PositContentCredentialsStrategy(
123+
local_strategy=mock_strategy(),
124+
client=client,
125+
)
126+
cp = cs()
127+
assert cs.auth_type() == "posit-oauth-integration"
128+
assert cp() == {"Authorization": "Bearer content-access-token"}
129+
130+
56131
@responses.activate
57132
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
58133
def test_posit_credentials_strategy(self):
@@ -69,6 +144,17 @@ def test_posit_credentials_strategy(self):
69144
assert cs.auth_type() == "posit-oauth-integration"
70145
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}
71146

147+
def test_posit_content_credentials_strategy_fallback(self):
148+
# local_strategy is used when the content is running locally
149+
client = Client(api_key="12345", url="https://connect.example/")
150+
cs = PositContentCredentialsStrategy(
151+
local_strategy=mock_strategy(),
152+
client=client,
153+
)
154+
cp = cs()
155+
assert cs.auth_type() == "local"
156+
assert cp() == {"Authorization": "Bearer static-pat-token"}
157+
72158
def test_posit_credentials_strategy_fallback(self):
73159
# local_strategy is used when the content is running locally
74160
client = Client(api_key="12345", url="https://connect.example/")

0 commit comments

Comments
 (0)