Skip to content

Commit 331a59e

Browse files
feat: add optional audience parameter to OAuth credential exchange methods and related functions
1 parent def37c7 commit 331a59e

File tree

6 files changed

+70
-14
lines changed

6 files changed

+70
-14
lines changed

src/posit/connect/client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from typing import Optional
6+
57
from typing_extensions import TYPE_CHECKING, overload
68

79
from . import hooks, me
@@ -176,7 +178,11 @@ def __init__(self, *args, **kwargs) -> None:
176178
self._ctx = Context(self)
177179

178180
@requires("2025.01.0")
179-
def with_user_session_token(self, token: str) -> Client:
181+
def with_user_session_token(
182+
self,
183+
token: str,
184+
audience: Optional[str] = None,
185+
) -> Client:
180186
"""Create a new Client scoped to the user specified in the user session token.
181187
182188
Create a new Client instance from a user session token exchange for an api key scoped to the
@@ -256,7 +262,9 @@ def user_profile():
256262
raise ValueError("token must be set to non-empty string.")
257263

258264
visitor_credentials = self.oauth.get_credentials(
259-
token, requested_token_type=OAuthTokenType.API_KEY
265+
token,
266+
requested_token_type=OAuthTokenType.API_KEY,
267+
audience=audience,
260268
)
261269

262270
visitor_api_key = visitor_credentials.get("access_token", "")

src/posit/connect/external/aws.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ class Credentials(TypedDict):
1919
expiration: datetime
2020

2121

22-
def get_credentials(client: Client, user_session_token: str) -> Credentials:
22+
def get_credentials(
23+
client: Client,
24+
user_session_token: str,
25+
audience: Optional[str] = None,
26+
) -> Credentials:
2327
"""
2428
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
2529
@@ -66,6 +70,7 @@ def get_credentials(client: Client, user_session_token: str) -> Credentials:
6670
credentials = client.oauth.get_credentials(
6771
user_session_token=user_session_token,
6872
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
73+
audience=audience,
6974
)
7075

7176
# Decode base64 access token
@@ -76,7 +81,9 @@ def get_credentials(client: Client, user_session_token: str) -> Credentials:
7681

7782

7883
def get_content_credentials(
79-
client: Client, content_session_token: Optional[str] = None
84+
client: Client,
85+
content_session_token: Optional[str] = None,
86+
audience: Optional[str] = None,
8087
) -> Credentials:
8188
"""
8289
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
@@ -122,6 +129,7 @@ def get_content_credentials(
122129
credentials = client.oauth.get_content_credentials(
123130
content_session_token=content_session_token,
124131
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
132+
audience=audience,
125133
)
126134

127135
# Decode base64 access token

src/posit/connect/external/databricks.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,16 @@ class _PositConnectContentCredentialsProvider:
6262
* https://github.com/posit-dev/posit-sdk-py/blob/main/src/posit/connect/oauth/oauth.py
6363
"""
6464

65-
def __init__(self, client: Client):
65+
def __init__(
66+
self,
67+
client: Client,
68+
audience: Optional[str] = None,
69+
):
6670
self._client = client
71+
self._audience = audience
6772

6873
def __call__(self) -> Dict[str, str]:
69-
credentials = self._client.oauth.get_content_credentials()
74+
credentials = self._client.oauth.get_content_credentials(audience=self._audience)
7075
return _new_bearer_authorization_header(credentials)
7176

7277

@@ -81,12 +86,21 @@ class _PositConnectViewerCredentialsProvider:
8186
* https://github.com/posit-dev/posit-sdk-py/blob/main/src/posit/connect/oauth/oauth.py
8287
"""
8388

84-
def __init__(self, client: Client, user_session_token: str):
89+
def __init__(
90+
self,
91+
client: Client,
92+
user_session_token: str,
93+
audience: Optional[str] = None,
94+
):
8595
self._client = client
8696
self._user_session_token = user_session_token
97+
self._audience = audience
8798

8899
def __call__(self) -> Dict[str, str]:
89-
credentials = self._client.oauth.get_credentials(self._user_session_token)
100+
credentials = self._client.oauth.get_credentials(
101+
self._user_session_token,
102+
audience=self._audience,
103+
)
90104
return _new_bearer_authorization_header(credentials)
91105

92106

@@ -174,10 +188,12 @@ def __init__(
174188
self,
175189
client: Optional[Client] = None,
176190
user_session_token: Optional[str] = None,
191+
audience: Optional[str] = None,
177192
):
178193
self._cp: Optional[CredentialsProvider] = None
179194
self._client = client
180195
self._user_session_token = user_session_token
196+
self._audience = audience
181197

182198
def auth_type(self) -> str:
183199
return POSIT_OAUTH_INTEGRATION_AUTH_TYPE
@@ -194,13 +210,18 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider: # noqa: ARG002
194210
if self._cp is None:
195211
if self._user_session_token:
196212
self._cp = _PositConnectViewerCredentialsProvider(
197-
self._client, self._user_session_token
213+
self._client,
214+
self._user_session_token,
215+
audience=self._audience,
198216
)
199217
else:
200218
logger.info(
201219
"ConnectStrategy will attempt to use OAuth Service Account credentials because user_session_token is not set"
202220
)
203-
self._cp = _PositConnectContentCredentialsProvider(self._client)
221+
self._cp = _PositConnectContentCredentialsProvider(
222+
self._client,
223+
audience=self._audience,
224+
)
204225
return self._cp
205226

206227

src/posit/connect/external/snowflake.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@ def __init__(
6969
local_authenticator: Optional[str] = None,
7070
client: Optional[Client] = None,
7171
user_session_token: Optional[str] = None,
72+
audience: Optional[str] = None,
7273
):
7374
self._local_authenticator = local_authenticator
7475
self._client = client
7576
self._user_session_token = user_session_token
77+
self._audience = audience
7678

7779
@property
7880
def authenticator(self) -> Optional[str]:
@@ -93,5 +95,8 @@ def token(self) -> Optional[str]:
9395
if self._client is None:
9496
self._client = Client()
9597

96-
credentials = self._client.oauth.get_credentials(self._user_session_token)
98+
credentials = self._client.oauth.get_credentials(
99+
self._user_session_token,
100+
audience=self._audience,
101+
)
97102
return credentials.get("access_token")

src/posit/connect/oauth/associations.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
"""OAuth association resources."""
22

3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
37
from typing_extensions import List
48

5-
from ..context import Context
69
from ..resources import BaseResource, Resources
710

11+
if TYPE_CHECKING:
12+
from ..context import Context
13+
814

915
class Association(BaseResource):
1016
pass
@@ -66,9 +72,11 @@ def delete(self) -> None:
6672
path = f"v1/content/{self.content_guid}/oauth/integrations/associations"
6773
self._ctx.client.put(path, json=data)
6874

69-
def update(self, integration_guid: str) -> None:
75+
def update(self, integration_guid: str | list[str]) -> None:
7076
"""Set integration associations."""
71-
data = [{"oauth_integration_guid": integration_guid}]
77+
if isinstance(integration_guid, str):
78+
integration_guid = [integration_guid]
79+
data = [{"oauth_integration_guid": guid} for guid in integration_guid]
7280

7381
path = f"v1/content/{self.content_guid}/oauth/integrations/associations"
7482
self._ctx.client.put(path, json=data)

src/posit/connect/oauth/oauth.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def get_credentials(
6262
self,
6363
user_session_token: Optional[str] = None,
6464
requested_token_type: Optional[str | OAuthTokenType] = None,
65+
audience: Optional[str] = None,
6566
) -> Credentials:
6667
"""Perform an oauth credential exchange with a user-session-token."""
6768
# craft a credential exchange request
@@ -72,6 +73,8 @@ def get_credentials(
7273
data["subject_token"] = user_session_token
7374
if requested_token_type:
7475
data["requested_token_type"] = requested_token_type
76+
if audience:
77+
data["audience"] = audience
7578

7679
response = self._ctx.client.post(self._path, data=data)
7780
return Credentials(**response.json())
@@ -80,6 +83,7 @@ def get_content_credentials(
8083
self,
8184
content_session_token: Optional[str] = None,
8285
requested_token_type: Optional[str | OAuthTokenType] = None,
86+
audience: Optional[str] = None,
8387
) -> Credentials:
8488
"""Perform an oauth credential exchange with a content-session-token."""
8589
# craft a credential exchange request
@@ -89,6 +93,8 @@ def get_content_credentials(
8993
data["subject_token"] = content_session_token or _get_content_session_token()
9094
if requested_token_type:
9195
data["requested_token_type"] = requested_token_type
96+
if audience:
97+
data["audience"] = audience
9298

9399
response = self._ctx.client.post(self._path, data=data)
94100
return Credentials(**response.json())

0 commit comments

Comments
 (0)