Skip to content

Commit 54292f4

Browse files
feat: add optional audience parameter to credential exchange related methods (#419)
- Adds in support for the optional audience parameter for the credential exchange endpoint to determine which integration to use - Removes "only 1 association" constraint on updating content's oauth integration associations. - Adds `find_by` helper method for integrations and associations to allow a user to search available integrations (globally or scoped to a specific piece of content) to get the GUID for the integration they need dynamically.
1 parent 01524f3 commit 54292f4

File tree

18 files changed

+688
-53
lines changed

18 files changed

+688
-53
lines changed

integration/tests/posit/connect/oauth/test_associations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def setup_class(cls):
7070
task = bundle.deploy()
7171
task.wait_for()
7272

73-
cls.content.oauth.associations.update(cls.integration["guid"])
73+
cls.content.oauth.associations.update([cls.integration["guid"]])
7474

7575
@classmethod
7676
def teardown_class(cls):
@@ -102,7 +102,7 @@ def test_find_update_by_content(self):
102102
assert associations[0]["oauth_integration_guid"] == self.integration["guid"]
103103

104104
# update content association to another_integration
105-
self.content.oauth.associations.update(self.another_integration["guid"])
105+
self.content.oauth.associations.update([self.another_integration["guid"]])
106106
updated_associations = self.content.oauth.associations.find()
107107
assert len(updated_associations) == 1
108108
assert updated_associations[0]["app_guid"] == self.content["guid"]

src/posit/connect/client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing_extensions import TYPE_CHECKING, overload
5+
from typing_extensions import TYPE_CHECKING, Optional, overload
66

77
from . import hooks, me
88
from .auth import Auth
@@ -11,7 +11,8 @@
1111
from .context import Context, ContextManager, requires
1212
from .groups import Groups
1313
from .metrics.metrics import Metrics
14-
from .oauth.oauth import OAuth, OAuthTokenType
14+
from .oauth.oauth import OAuth
15+
from .oauth.types import OAuthTokenType
1516
from .resources import _PaginatedResourceSequence, _ResourceSequence
1617
from .sessions import Session
1718
from .system import System
@@ -176,7 +177,11 @@ def __init__(self, *args, **kwargs) -> None:
176177
self._ctx = Context(self)
177178

178179
@requires("2025.01.0")
179-
def with_user_session_token(self, token: str) -> Client:
180+
def with_user_session_token(
181+
self,
182+
token: str,
183+
audience: Optional[str] = None,
184+
) -> Client:
180185
"""Create a new Client scoped to the user specified in the user session token.
181186
182187
Create a new Client instance from a user session token exchange for an api key scoped to the
@@ -256,7 +261,9 @@ def user_profile():
256261
raise ValueError("token must be set to non-empty string.")
257262

258263
visitor_credentials = self.oauth.get_credentials(
259-
token, requested_token_type=OAuthTokenType.API_KEY
264+
token,
265+
requested_token_type=OAuthTokenType.API_KEY,
266+
audience=audience,
260267
)
261268

262269
visitor_api_key = visitor_credentials.get("access_token", "")

src/posit/connect/content.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import os
56
import posixpath
67
import time
78

@@ -983,18 +984,26 @@ def find_one(self, **conditions) -> Optional[ContentItem]:
983984
items = self.find(**conditions)
984985
return next(iter(items), None)
985986

986-
def get(self, guid: str) -> ContentItem:
987+
def get(self, guid: Optional[str] = None) -> ContentItem:
987988
"""Get a content item.
988989
990+
If `guid` is None, attempts to get the content item for the current context using the
991+
CONNECT_CONTENT_GUID environment variable, which is automatically set when running on Connect.
992+
989993
Parameters
990994
----------
991-
guid : str
995+
guid : str, optional
992996
The unique identifier of the content item.
993997
994998
Returns
995999
-------
9961000
ContentItem
9971001
"""
1002+
if guid is None:
1003+
guid = os.getenv("CONNECT_CONTENT_GUID")
1004+
if not guid:
1005+
raise RuntimeError("CONNECT_CONTENT_GUID environment variable is not set.")
1006+
9981007
# Always request all available optional fields for the content item
9991008
params = {"include": "owner,tags,vanity_url"}
10001009

src/posit/connect/external/aws.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing_extensions import TYPE_CHECKING, Optional, TypedDict
88

9-
from ..oauth.oauth import OAuthTokenType
9+
from ..oauth.types import OAuthTokenType
1010

1111
if TYPE_CHECKING:
1212
from ..client import Client
@@ -19,7 +19,11 @@ class Credentials(TypedDict):
1919
expiration: datetime
2020

2121

22-
def get_credentials(client: Client, user_session_token: str) -> Credentials:
22+
def get_credentials(
23+
client: Client,
24+
user_session_token: str,
25+
audience: Optional[str] = None,
26+
) -> Credentials:
2327
"""
2428
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
2529
@@ -66,6 +70,7 @@ def get_credentials(client: Client, user_session_token: str) -> Credentials:
6670
credentials = client.oauth.get_credentials(
6771
user_session_token=user_session_token,
6872
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
73+
audience=audience,
6974
)
7075

7176
# Decode base64 access token
@@ -76,7 +81,9 @@ def get_credentials(client: Client, user_session_token: str) -> Credentials:
7681

7782

7883
def get_content_credentials(
79-
client: Client, content_session_token: Optional[str] = None
84+
client: Client,
85+
content_session_token: Optional[str] = None,
86+
audience: Optional[str] = None,
8087
) -> Credentials:
8188
"""
8289
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
@@ -122,6 +129,7 @@ def get_content_credentials(
122129
credentials = client.oauth.get_content_credentials(
123130
content_session_token=content_session_token,
124131
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
132+
audience=audience,
125133
)
126134

127135
# Decode base64 access token

src/posit/connect/external/databricks.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,16 @@ class _PositConnectContentCredentialsProvider:
6262
* https://github.com/posit-dev/posit-sdk-py/blob/main/src/posit/connect/oauth/oauth.py
6363
"""
6464

65-
def __init__(self, client: Client):
65+
def __init__(
66+
self,
67+
client: Client,
68+
audience: Optional[str] = None,
69+
):
6670
self._client = client
71+
self._audience = audience
6772

6873
def __call__(self) -> Dict[str, str]:
69-
credentials = self._client.oauth.get_content_credentials()
74+
credentials = self._client.oauth.get_content_credentials(audience=self._audience)
7075
return _new_bearer_authorization_header(credentials)
7176

7277

@@ -81,12 +86,21 @@ class _PositConnectViewerCredentialsProvider:
8186
* https://github.com/posit-dev/posit-sdk-py/blob/main/src/posit/connect/oauth/oauth.py
8287
"""
8388

84-
def __init__(self, client: Client, user_session_token: str):
89+
def __init__(
90+
self,
91+
client: Client,
92+
user_session_token: str,
93+
audience: Optional[str] = None,
94+
):
8595
self._client = client
8696
self._user_session_token = user_session_token
97+
self._audience = audience
8798

8899
def __call__(self) -> Dict[str, str]:
89-
credentials = self._client.oauth.get_credentials(self._user_session_token)
100+
credentials = self._client.oauth.get_credentials(
101+
self._user_session_token,
102+
audience=self._audience,
103+
)
90104
return _new_bearer_authorization_header(credentials)
91105

92106

@@ -174,10 +188,12 @@ def __init__(
174188
self,
175189
client: Optional[Client] = None,
176190
user_session_token: Optional[str] = None,
191+
audience: Optional[str] = None,
177192
):
178193
self._cp: Optional[CredentialsProvider] = None
179194
self._client = client
180195
self._user_session_token = user_session_token
196+
self._audience = audience
181197

182198
def auth_type(self) -> str:
183199
return POSIT_OAUTH_INTEGRATION_AUTH_TYPE
@@ -194,13 +210,18 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider: # noqa: ARG002
194210
if self._cp is None:
195211
if self._user_session_token:
196212
self._cp = _PositConnectViewerCredentialsProvider(
197-
self._client, self._user_session_token
213+
self._client,
214+
self._user_session_token,
215+
audience=self._audience,
198216
)
199217
else:
200218
logger.info(
201219
"ConnectStrategy will attempt to use OAuth Service Account credentials because user_session_token is not set"
202220
)
203-
self._cp = _PositConnectContentCredentialsProvider(self._client)
221+
self._cp = _PositConnectContentCredentialsProvider(
222+
self._client,
223+
audience=self._audience,
224+
)
204225
return self._cp
205226

206227

src/posit/connect/external/snowflake.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@ def __init__(
6969
local_authenticator: Optional[str] = None,
7070
client: Optional[Client] = None,
7171
user_session_token: Optional[str] = None,
72+
audience: Optional[str] = None,
7273
):
7374
self._local_authenticator = local_authenticator
7475
self._client = client
7576
self._user_session_token = user_session_token
77+
self._audience = audience
7678

7779
@property
7880
def authenticator(self) -> Optional[str]:
@@ -93,5 +95,8 @@ def token(self) -> Optional[str]:
9395
if self._client is None:
9496
self._client = Client()
9597

96-
credentials = self._client.oauth.get_credentials(self._user_session_token)
98+
credentials = self._client.oauth.get_credentials(
99+
self._user_session_token,
100+
audience=self._audience,
101+
)
97102
return credentials.get("access_token")

src/posit/connect/oauth/associations.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
"""OAuth association resources."""
22

3-
from typing_extensions import List
3+
from __future__ import annotations
44

5-
from ..context import Context
6-
from ..resources import BaseResource, Resources
5+
from functools import partial
6+
7+
from typing_extensions import TYPE_CHECKING, List, Optional
8+
9+
from ..context import requires
10+
from ..resources import BaseResource, Resources, _matches_exact, _matches_pattern
11+
12+
if TYPE_CHECKING:
13+
from ..context import Context
14+
from ..oauth import types
715

816

917
class Association(BaseResource):
@@ -59,16 +67,71 @@ def find(self) -> List[Association]:
5967
for result in response.json()
6068
]
6169

70+
@requires("2025.07.0-dev")
71+
def find_by(
72+
self,
73+
integration_type: Optional[types.OAuthIntegrationType | str] = None,
74+
auth_type: Optional[types.OAuthIntegrationAuthType | str] = None,
75+
name: Optional[str] = None,
76+
description: Optional[str] = None,
77+
guid: Optional[str] = None,
78+
) -> Association | None:
79+
"""Find an OAuth integration associated with content by various criteria.
80+
81+
Parameters
82+
----------
83+
integration_type : Optional[types.OAuthIntegrationType | str]
84+
The type of the integration (e.g., "aws", "azure").
85+
auth_type : Optional[types.OAuthIntegrationAuthType | str]
86+
The authentication type of the integration (e.g., "Viewer", "Service Account").
87+
name : Optional[str]
88+
A regex pattern to match the integration name. For exact matches, use `^` and `$`. For example,
89+
`^My Integration$` will match only "My Integration".
90+
description : Optional[str]
91+
A regex pattern to match the integration description. For exact matches, use `^` and `$`. For example,
92+
`^My Integration Description$` will match only "My Integration Description".
93+
guid : Optional[str]
94+
The unique identifier of the integration.
95+
96+
Returns
97+
-------
98+
Association | None
99+
The first matching association, or None if no match is found.
100+
"""
101+
filters = []
102+
if integration_type is not None:
103+
filters.append(
104+
partial(_matches_exact, key="oauth_integration_template", value=integration_type)
105+
)
106+
if auth_type is not None:
107+
filters.append(
108+
partial(_matches_exact, key="oauth_integration_auth_type", value=auth_type)
109+
)
110+
if name is not None:
111+
filters.append(partial(_matches_pattern, key="oauth_integration_name", pattern=name))
112+
if description is not None:
113+
filters.append(
114+
partial(_matches_pattern, key="oauth_integration_description", pattern=description)
115+
)
116+
if guid is not None:
117+
filters.append(partial(_matches_exact, key="oauth_integration_guid", value=guid))
118+
119+
for association in self.find():
120+
if all(f(association) for f in filters):
121+
return association
122+
123+
return None
124+
62125
def delete(self) -> None:
63126
"""Delete integration associations."""
64127
data = []
65128

66129
path = f"v1/content/{self.content_guid}/oauth/integrations/associations"
67130
self._ctx.client.put(path, json=data)
68131

69-
def update(self, integration_guid: str) -> None:
132+
def update(self, integration_guids: list[str]) -> None:
70133
"""Set integration associations."""
71-
data = [{"oauth_integration_guid": integration_guid}]
134+
data = [{"oauth_integration_guid": guid} for guid in integration_guids]
72135

73136
path = f"v1/content/{self.content_guid}/oauth/integrations/associations"
74137
self._ctx.client.put(path, json=data)

0 commit comments

Comments
 (0)