Skip to content

Commit 23d5778

Browse files
feat: Add support for AWS integration content session token exchange (#399)
Similar to existing helper for Viewer integrations but the content session token counterpart. Usage example in docstring for function. Tests updated.
1 parent d4b350e commit 23d5778

File tree

4 files changed

+218
-34
lines changed

4 files changed

+218
-34
lines changed

src/posit/connect/external/aws.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,52 @@
22

33
import base64
44
import json
5+
from datetime import datetime
56

6-
from typing_extensions import TYPE_CHECKING, Dict
7+
from typing_extensions import TYPE_CHECKING, Optional, TypedDict
78

89
from ..oauth.oauth import OAuthTokenType
910

1011
if TYPE_CHECKING:
1112
from ..client import Client
1213

1314

14-
def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, str]:
15+
class Credentials(TypedDict):
16+
aws_access_key_id: str
17+
aws_secret_access_key: str
18+
aws_session_token: str
19+
expiration: datetime
20+
21+
22+
def get_credentials(client: Client, user_session_token: str) -> Credentials:
1523
"""
16-
Get AWS credentials using OAuth token exchange.
24+
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
1725
18-
According to RFC 8693, the access token must be a base64 encoded JSON object
19-
containing the AWS credentials. This function will decode and deserialize the
20-
access token and return the AWS credentials.
26+
According to RFC 8693, the access token must be a base64-encoded JSON object
27+
containing the AWS credentials. This function will return the decoded and
28+
deserialized AWS credentials.
29+
30+
Examples
31+
--------
32+
```python
33+
from posit.connect import Client
34+
from posit.connect.external.aws import get_aws_credentials
35+
import boto3
36+
from shiny.express import session
37+
38+
client = Client()
39+
session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token")
40+
credentials = get_aws_credentials(client, user_session_token)
41+
aws_session_expiration = credentials["expiration"]
42+
aws_session = boto3.Session(
43+
aws_access_key_id=credentials["aws_access_key_id"],
44+
aws_secret_access_key=credentials["aws_secret_access_key"],
45+
aws_session_token=credentials["aws_session_token"],
46+
)
47+
48+
s3 = aws_session.resource("s3")
49+
bucket = s3.Bucket("your-bucket-name")
50+
```
2151
2252
Parameters
2353
----------
@@ -42,8 +72,91 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
4272
access_token = credentials.get("access_token")
4373
if not access_token:
4474
raise ValueError("No access token found in credentials")
75+
return _decode_access_token(access_token)
76+
77+
78+
def get_content_credentials(
79+
client: Client, content_session_token: Optional[str] = None
80+
) -> Credentials:
81+
"""
82+
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
83+
84+
According to RFC 8693, the access token must be a base64-encoded JSON object
85+
containing the AWS credentials. This function will return the decoded and
86+
deserialized AWS credentials.
87+
88+
Examples
89+
--------
90+
```python
91+
from posit.connect import Client
92+
from posit.connect.external.aws import get_aws_content_credentials
93+
import boto3
94+
95+
client = Client()
96+
credentials = get_aws_content_credentials(client)
97+
session_expiration = credentials["expiration"]
98+
aws_session = boto3.Session(
99+
aws_access_key_id=credentials["aws_access_key_id"],
100+
aws_secret_access_key=credentials["aws_secret_access_key"],
101+
aws_session_token=credentials["aws_session_token"],
102+
)
103+
104+
s3 = session.resource("s3")
105+
bucket = s3.Bucket("your-bucket-name")
106+
```
107+
108+
Parameters
109+
----------
110+
client : Client
111+
The client to use for making requests
112+
content_session_token : str
113+
The content session token to exchange
114+
115+
Returns
116+
-------
117+
Dict[str, str]
118+
Dictionary containing AWS credentials with keys:
119+
access_key_id, secret_access_key, session_token, and expiration
120+
"""
121+
# Get credentials using OAuth
122+
credentials = client.oauth.get_content_credentials(
123+
content_session_token=content_session_token,
124+
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
125+
)
126+
127+
# Decode base64 access token
128+
access_token = credentials.get("access_token")
129+
if not access_token:
130+
raise ValueError("No access token found in credentials")
131+
return _decode_access_token(access_token)
132+
133+
134+
def _decode_access_token(access_token: str) -> Credentials:
135+
"""
136+
Decode and deserialize an access token containing AWS credentials.
137+
138+
According to RFC 8693, the access token must be a base64-encoded JSON object
139+
containing the AWS credentials. This function will decode and deserialize the
140+
access token and return the AWS credentials.
141+
142+
Parameters
143+
----------
144+
access_token : str
145+
The access token to decode
146+
147+
Returns
148+
-------
149+
Credentials
150+
Dictionary containing AWS credentials with keys:
151+
access_key_id, secret_access_key, session_token, and expiration
152+
"""
45153
decoded_bytes = base64.b64decode(access_token)
46154
decoded_str = decoded_bytes.decode("utf-8")
47155
aws_credentials = json.loads(decoded_str)
48156

49-
return aws_credentials
157+
return Credentials(
158+
aws_access_key_id=aws_credentials["accessKeyId"],
159+
aws_secret_access_key=aws_credentials["secretAccessKey"],
160+
aws_session_token=aws_credentials["sessionToken"],
161+
expiration=datetime.strptime(aws_credentials["expiration"], "%Y-%m-%dT%H:%M:%SZ"),
162+
)

src/posit/connect/oauth/oauth.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ def get_credentials(
7676
response = self._ctx.client.post(self._path, data=data)
7777
return Credentials(**response.json())
7878

79-
def get_content_credentials(self, content_session_token: Optional[str] = None) -> Credentials:
79+
def get_content_credentials(
80+
self,
81+
content_session_token: Optional[str] = None,
82+
requested_token_type: Optional[str | OAuthTokenType] = None,
83+
) -> Credentials:
8084
"""Perform an oauth credential exchange with a content-session-token."""
8185
# craft a credential exchange request
8286
data = {}
8387
data["grant_type"] = GRANT_TYPE
8488
data["subject_token_type"] = OAuthTokenType.CONTENT_SESSION_TOKEN
8589
data["subject_token"] = content_session_token or _get_content_session_token()
90+
if requested_token_type:
91+
data["requested_token_type"] = requested_token_type
8692

8793
response = self._ctx.client.post(self._path, data=data)
8894
return Credentials(**response.json())

tests/posit/connect/external/test_aws.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1+
from datetime import datetime
2+
13
import pytest
24
import responses
35

46
from posit.connect import Client
5-
from posit.connect.external.aws import get_aws_credentials
7+
from posit.connect.external.aws import (
8+
Credentials,
9+
_decode_access_token,
10+
get_content_credentials,
11+
get_credentials,
12+
)
613

7-
aws_creds = {
8-
"accessKeyId": "abc123",
9-
"secretAccessKey": "def456",
10-
"sessionToken": "ghi789",
11-
"expiration": "2025-01-01T00:00:00Z",
12-
}
14+
aws_creds = Credentials(
15+
aws_access_key_id="abc123",
16+
aws_secret_access_key="def456",
17+
aws_session_token="ghi789",
18+
expiration=datetime(2025, 1, 1, 0, 0, 0, 0),
19+
)
1320

1421
encoded_aws_creds = "eyJhY2Nlc3NLZXlJZCI6ICJhYmMxMjMiLCAic2VjcmV0QWNjZXNzS2V5IjogImRlZjQ1NiIsICJzZXNzaW9uVG9rZW4iOiAiZ2hpNzg5IiwgImV4cGlyYXRpb24iOiAiMjAyNS0wMS0wMVQwMDowMDowMFoifQ=="
1522

@@ -38,7 +45,7 @@ def test_get_aws_credentials(self):
3845

3946
c = Client(api_key="12345", url="https://connect.example/")
4047
c._ctx.version = None
41-
response = get_aws_credentials(c, "cit")
48+
response = get_credentials(c, "cit")
4249

4350
assert response == aws_creds
4451

@@ -63,6 +70,62 @@ def test_get_aws_credentials_no_token(self):
6370
c._ctx.version = None
6471

6572
with pytest.raises(ValueError) as e:
66-
get_aws_credentials(c, "cit")
73+
get_credentials(c, "cit")
74+
75+
assert e.match("No access token found in credentials")
76+
77+
@responses.activate
78+
def test_get_aws_content_credentials(self):
79+
responses.post(
80+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
81+
match=[
82+
responses.matchers.urlencoded_params_matcher(
83+
{
84+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
85+
"subject_token_type": "urn:posit:connect:content-session-token",
86+
"subject_token": "cit",
87+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
88+
}
89+
)
90+
],
91+
json={
92+
"access_token": encoded_aws_creds,
93+
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
94+
"token_type": "aws_credentials",
95+
},
96+
)
97+
98+
c = Client(api_key="12345", url="https://connect.example/")
99+
c._ctx.version = None
100+
response = get_content_credentials(c, "cit")
101+
102+
assert response == aws_creds
103+
104+
@responses.activate
105+
def test_get_aws_content_credentials_no_token(self):
106+
responses.post(
107+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
108+
match=[
109+
responses.matchers.urlencoded_params_matcher(
110+
{
111+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
112+
"subject_token_type": "urn:posit:connect:content-session-token",
113+
"subject_token": "cit",
114+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
115+
}
116+
)
117+
],
118+
json={},
119+
)
120+
121+
c = Client(api_key="12345", url="https://connect.example/")
122+
c._ctx.version = None
123+
124+
with pytest.raises(ValueError) as e:
125+
get_content_credentials(c, "cit")
67126

68127
assert e.match("No access token found in credentials")
128+
129+
def test_decode_access_token(self):
130+
decoded_creds = _decode_access_token(encoded_aws_creds)
131+
assert decoded_creds == aws_creds

tests/posit/connect/oauth/test_oauth.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_get_credentials(self):
2626
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
2727
"subject_token_type": "urn:posit:connect:user-session-token",
2828
"subject_token": "cit",
29+
# no requested token type set
2930
},
3031
),
3132
],
@@ -41,7 +42,7 @@ def test_get_credentials(self):
4142
assert creds.get("access_token") == "viewer-token"
4243

4344
@responses.activate
44-
def test_get_credentials_api_key(self):
45+
def test_get_credentials_with_requested_token_type(self):
4546
responses.post(
4647
"https://connect.example/__api__/v1/oauth/integrations/credentials",
4748
match=[
@@ -68,34 +69,32 @@ def test_get_credentials_api_key(self):
6869
assert creds.get("token_type") == "Key"
6970

7071
@responses.activate
71-
def test_get_credentials_aws(self):
72+
def test_get_content_credentials(self):
7273
responses.post(
7374
"https://connect.example/__api__/v1/oauth/integrations/credentials",
7475
match=[
7576
responses.matchers.urlencoded_params_matcher(
7677
{
7778
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
78-
"subject_token_type": "urn:posit:connect:user-session-token",
79+
"subject_token_type": "urn:posit:connect:content-session-token",
7980
"subject_token": "cit",
80-
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
81+
# no requested token type set
8182
},
8283
),
8384
],
8485
json={
85-
"access_token": "encoded-aws-creds",
86-
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
87-
"token_type": "aws_credentials",
86+
"access_token": "content-token",
87+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
88+
"token_type": "Bearer",
8889
},
8990
)
9091
c = Client(api_key="12345", url="https://connect.example/")
9192
c._ctx.version = None
92-
creds = c.oauth.get_credentials("cit", OAuthTokenType.AWS_CREDENTIALS)
93-
assert creds.get("access_token") == "encoded-aws-creds"
94-
assert creds.get("issued_token_type") == "urn:ietf:params:aws:token-type:credentials"
95-
assert creds.get("token_type") == "aws_credentials"
93+
creds = c.oauth.get_content_credentials("cit")
94+
assert creds.get("access_token") == "content-token"
9695

9796
@responses.activate
98-
def test_get_content_credentials(self):
97+
def test_get_content_credentials_with_requested_token_type(self):
9998
responses.post(
10099
"https://connect.example/__api__/v1/oauth/integrations/credentials",
101100
match=[
@@ -104,19 +103,22 @@ def test_get_content_credentials(self):
104103
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
105104
"subject_token_type": "urn:posit:connect:content-session-token",
106105
"subject_token": "cit",
106+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
107107
},
108108
),
109109
],
110110
json={
111-
"access_token": "content-token",
112-
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
113-
"token_type": "Bearer",
111+
"access_token": "encoded-aws-creds",
112+
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
113+
"token_type": "aws_credentials",
114114
},
115115
)
116116
c = Client(api_key="12345", url="https://connect.example/")
117117
c._ctx.version = None
118-
creds = c.oauth.get_content_credentials("cit")
119-
assert creds.get("access_token") == "content-token"
118+
creds = c.oauth.get_content_credentials("cit", OAuthTokenType.AWS_CREDENTIALS)
119+
assert creds.get("access_token") == "encoded-aws-creds"
120+
assert creds.get("issued_token_type") == "urn:ietf:params:aws:token-type:credentials"
121+
assert creds.get("token_type") == "aws_credentials"
120122

121123
@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
122124
@responses.activate

0 commit comments

Comments
 (0)