Skip to content

Commit 20a3908

Browse files
added new function to support aws service accoutn integrations
1 parent d4b350e commit 20a3908

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

src/posit/connect/external/aws.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import base64
44
import json
55

6-
from typing_extensions import TYPE_CHECKING, Dict
6+
from typing_extensions import TYPE_CHECKING, Dict, Optional
77

88
from ..oauth.oauth import OAuthTokenType
99

@@ -13,12 +13,33 @@
1313

1414
def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, str]:
1515
"""
16-
Get AWS credentials using OAuth token exchange.
16+
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
1717
1818
According to RFC 8693, the access token must be a base64 encoded JSON object
1919
containing the AWS credentials. This function will decode and deserialize the
2020
access token and return the AWS credentials.
2121
22+
Examples
23+
--------
24+
```python
25+
from posit.connect import Client
26+
from posit.connect.external.aws import get_aws_credentials
27+
import boto3
28+
from shiny.express import session
29+
30+
client = Client()
31+
session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token")
32+
credentials = get_aws_credentials(client, user_session_token)
33+
aws_session = boto3.Session(
34+
aws_access_key_id=credentials["accessKeyId"],
35+
aws_secret_access_key=credentials["secretAccessKey"],
36+
aws_session_token=credentials["sessionToken"],
37+
)
38+
39+
s3 = aws_session.resource("s3")
40+
bucket = s3.Bucket("your-bucket-name")
41+
```
42+
2243
Parameters
2344
----------
2445
client : Client
@@ -42,6 +63,83 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
4263
access_token = credentials.get("access_token")
4364
if not access_token:
4465
raise ValueError("No access token found in credentials")
66+
return _decode_access_token(access_token)
67+
68+
69+
def get_aws_content_credentials(
70+
client: Client, content_session_token: Optional[str] = None
71+
) -> Dict[str, str]:
72+
"""
73+
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
74+
75+
According to RFC 8693, the access token must be a base64 encoded JSON object
76+
containing the AWS credentials. This function will decode and deserialize the
77+
access token and return the AWS credentials.
78+
79+
Examples
80+
--------
81+
```python
82+
from posit.connect import Client
83+
from posit.connect.external.aws import get_aws_content_credentials
84+
import boto3
85+
86+
client = Client()
87+
credentials = get_aws_content_credentials(client)
88+
session = boto3.Session(
89+
aws_access_key_id=credentials["accessKeyId"],
90+
aws_secret_access_key=credentials["secretAccessKey"],
91+
aws_session_token=credentials["sessionToken"],
92+
)
93+
94+
s3 = session.resource("s3")
95+
bucket = s3.Bucket("your-bucket-name")
96+
```
97+
98+
Parameters
99+
----------
100+
client : Client
101+
The client to use for making requests
102+
content_session_token : str
103+
The content session token to exchange
104+
105+
Returns
106+
-------
107+
Dict[str, str]
108+
Dictionary containing AWS credentials with keys:
109+
access_key_id, secret_access_key, session_token, and expiration
110+
"""
111+
# Get credentials using OAuth
112+
credentials = client.oauth.get_content_credentials(
113+
content_session_token=content_session_token,
114+
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
115+
)
116+
117+
# Decode base64 access token
118+
access_token = credentials.get("access_token")
119+
if not access_token:
120+
raise ValueError("No access token found in credentials")
121+
return _decode_access_token(access_token)
122+
123+
124+
def _decode_access_token(access_token: str) -> Dict[str, str]:
125+
"""
126+
Decode and deserialize an access token containing AWS credentials.
127+
128+
According to RFC 8693, the access token must be a base64 encoded JSON object
129+
containing the AWS credentials. This function will decode and deserialize the
130+
access token and return the AWS credentials.
131+
132+
Parameters
133+
----------
134+
access_token : str
135+
The access token to decode
136+
137+
Returns
138+
-------
139+
Dict[str, str]
140+
Dictionary containing AWS credentials with keys:
141+
access_key_id, secret_access_key, session_token, and expiration
142+
"""
45143
decoded_bytes = base64.b64decode(access_token)
46144
decoded_str = decoded_bytes.decode("utf-8")
47145
aws_credentials = json.loads(decoded_str)

src/posit/connect/oauth/oauth.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ def get_credentials(
7676
response = self._ctx.client.post(self._path, data=data)
7777
return Credentials(**response.json())
7878

79-
def get_content_credentials(self, content_session_token: Optional[str] = None) -> Credentials:
79+
def get_content_credentials(
80+
self,
81+
content_session_token: Optional[str] = None,
82+
requested_token_type: Optional[str | OAuthTokenType] = None,
83+
) -> Credentials:
8084
"""Perform an oauth credential exchange with a content-session-token."""
8185
# craft a credential exchange request
8286
data = {}
8387
data["grant_type"] = GRANT_TYPE
8488
data["subject_token_type"] = OAuthTokenType.CONTENT_SESSION_TOKEN
8589
data["subject_token"] = content_session_token or _get_content_session_token()
90+
if requested_token_type:
91+
data["requested_token_type"] = requested_token_type
8692

8793
response = self._ctx.client.post(self._path, data=data)
8894
return Credentials(**response.json())

0 commit comments

Comments
 (0)