22
33import base64
44import json
5+ from datetime import datetime
56
6- from typing_extensions import TYPE_CHECKING , Dict
7+ from typing_extensions import TYPE_CHECKING , Optional , TypedDict
78
89from ..oauth .oauth import OAuthTokenType
910
1011if TYPE_CHECKING :
1112 from ..client import Client
1213
1314
14- def get_aws_credentials (client : Client , user_session_token : str ) -> Dict [str , str ]:
15+ class Credentials (TypedDict ):
16+ aws_access_key_id : str
17+ aws_secret_access_key : str
18+ aws_session_token : str
19+ expiration : datetime
20+
21+
22+ def get_credentials (client : Client , user_session_token : str ) -> Credentials :
1523 """
16- Get AWS credentials using OAuth token exchange.
24+ Get AWS credentials using OAuth token exchange for an AWS Viewer integration .
1725
18- According to RFC 8693, the access token must be a base64 encoded JSON object
19- containing the AWS credentials. This function will decode and deserialize the
20- access token and return the AWS credentials.
26+ According to RFC 8693, the access token must be a base64-encoded JSON object
27+ containing the AWS credentials. This function will return the decoded and
28+ deserialized AWS credentials.
29+
30+ Examples
31+ --------
32+ ```python
33+ from posit.connect import Client
34+ from posit.connect.external.aws import get_aws_credentials
35+ import boto3
36+ from shiny.express import session
37+
38+ client = Client()
39+ session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token")
40+ credentials = get_aws_credentials(client, user_session_token)
41+ aws_session_expiration = credentials["expiration"]
42+ aws_session = boto3.Session(
43+ aws_access_key_id=credentials["aws_access_key_id"],
44+ aws_secret_access_key=credentials["aws_secret_access_key"],
45+ aws_session_token=credentials["aws_session_token"],
46+ )
47+
48+ s3 = aws_session.resource("s3")
49+ bucket = s3.Bucket("your-bucket-name")
50+ ```
2151
2252 Parameters
2353 ----------
@@ -42,8 +72,91 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
4272 access_token = credentials .get ("access_token" )
4373 if not access_token :
4474 raise ValueError ("No access token found in credentials" )
75+ return _decode_access_token (access_token )
76+
77+
78+ def get_content_credentials (
79+ client : Client , content_session_token : Optional [str ] = None
80+ ) -> Credentials :
81+ """
82+ Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
83+
84+ According to RFC 8693, the access token must be a base64-encoded JSON object
85+ containing the AWS credentials. This function will return the decoded and
86+ deserialized AWS credentials.
87+
88+ Examples
89+ --------
90+ ```python
91+ from posit.connect import Client
92+ from posit.connect.external.aws import get_aws_content_credentials
93+ import boto3
94+
95+ client = Client()
96+ credentials = get_aws_content_credentials(client)
97+ session_expiration = credentials["expiration"]
98+ aws_session = boto3.Session(
99+ aws_access_key_id=credentials["aws_access_key_id"],
100+ aws_secret_access_key=credentials["aws_secret_access_key"],
101+ aws_session_token=credentials["aws_session_token"],
102+ )
103+
104+ s3 = session.resource("s3")
105+ bucket = s3.Bucket("your-bucket-name")
106+ ```
107+
108+ Parameters
109+ ----------
110+ client : Client
111+ The client to use for making requests
112+ content_session_token : str
113+ The content session token to exchange
114+
115+ Returns
116+ -------
117+ Dict[str, str]
118+ Dictionary containing AWS credentials with keys:
119+ access_key_id, secret_access_key, session_token, and expiration
120+ """
121+ # Get credentials using OAuth
122+ credentials = client .oauth .get_content_credentials (
123+ content_session_token = content_session_token ,
124+ requested_token_type = OAuthTokenType .AWS_CREDENTIALS ,
125+ )
126+
127+ # Decode base64 access token
128+ access_token = credentials .get ("access_token" )
129+ if not access_token :
130+ raise ValueError ("No access token found in credentials" )
131+ return _decode_access_token (access_token )
132+
133+
134+ def _decode_access_token (access_token : str ) -> Credentials :
135+ """
136+ Decode and deserialize an access token containing AWS credentials.
137+
138+ According to RFC 8693, the access token must be a base64-encoded JSON object
139+ containing the AWS credentials. This function will decode and deserialize the
140+ access token and return the AWS credentials.
141+
142+ Parameters
143+ ----------
144+ access_token : str
145+ The access token to decode
146+
147+ Returns
148+ -------
149+ Credentials
150+ Dictionary containing AWS credentials with keys:
151+ access_key_id, secret_access_key, session_token, and expiration
152+ """
45153 decoded_bytes = base64 .b64decode (access_token )
46154 decoded_str = decoded_bytes .decode ("utf-8" )
47155 aws_credentials = json .loads (decoded_str )
48156
49- return aws_credentials
157+ return Credentials (
158+ aws_access_key_id = aws_credentials ["accessKeyId" ],
159+ aws_secret_access_key = aws_credentials ["secretAccessKey" ],
160+ aws_session_token = aws_credentials ["sessionToken" ],
161+ expiration = datetime .strptime (aws_credentials ["expiration" ], "%Y-%m-%dT%H:%M:%SZ" ),
162+ )
0 commit comments