22
33import base64
44import json
5+ from datetime import datetime
56
6- from typing_extensions import TYPE_CHECKING , Dict , Optional
7+ from typing_extensions import TYPE_CHECKING , Optional , TypedDict
78
89from ..oauth .oauth import OAuthTokenType
910
1011if TYPE_CHECKING :
1112 from ..client import Client
1213
1314
14- def get_aws_credentials (client : Client , user_session_token : str ) -> Dict [str , str ]:
15+ class Credentials (TypedDict ):
16+ aws_access_key_id : str
17+ aws_secret_access_key : str
18+ aws_session_token : str
19+ expiration : datetime
20+
21+
22+ def get_credentials (client : Client , user_session_token : str ) -> Credentials :
1523 """
1624 Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
1725
18- According to RFC 8693, the access token must be a base64 encoded JSON object
19- containing the AWS credentials. This function will decode and deserialize the
20- access token and return the AWS credentials.
26+ According to RFC 8693, the access token must be a base64- encoded JSON object
27+ containing the AWS credentials. This function will return the decoded and
28+ deserialized AWS credentials.
2129
2230 Examples
2331 --------
@@ -32,9 +40,9 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
3240 credentials = get_aws_credentials(client, user_session_token)
3341 aws_session_expiration = credentials["expiration"]
3442 aws_session = boto3.Session(
35- aws_access_key_id=credentials["accessKeyId "],
36- aws_secret_access_key=credentials["secretAccessKey "],
37- aws_session_token=credentials["sessionToken "],
43+ aws_access_key_id=credentials["aws_access_key_id "],
44+ aws_secret_access_key=credentials["aws_secret_access_key "],
45+ aws_session_token=credentials["aws_session_token "],
3846 )
3947
4048 s3 = aws_session.resource("s3")
@@ -67,15 +75,15 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
6775 return _decode_access_token (access_token )
6876
6977
70- def get_aws_content_credentials (
78+ def get_content_credentials (
7179 client : Client , content_session_token : Optional [str ] = None
72- ) -> Dict [ str , str ] :
80+ ) -> Credentials :
7381 """
7482 Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
7583
76- According to RFC 8693, the access token must be a base64 encoded JSON object
77- containing the AWS credentials. This function will decode and deserialize the
78- access token and return the AWS credentials.
84+ According to RFC 8693, the access token must be a base64- encoded JSON object
85+ containing the AWS credentials. This function will return the decoded and
86+ deserialized AWS credentials.
7987
8088 Examples
8189 --------
@@ -87,10 +95,10 @@ def get_aws_content_credentials(
8795 client = Client()
8896 credentials = get_aws_content_credentials(client)
8997 session_expiration = credentials["expiration"]
90- session = boto3.Session(
91- aws_access_key_id=credentials["accessKeyId "],
92- aws_secret_access_key=credentials["secretAccessKey "],
93- aws_session_token=credentials["sessionToken "],
98+ aws_session = boto3.Session(
99+ aws_access_key_id=credentials["aws_access_key_id "],
100+ aws_secret_access_key=credentials["aws_secret_access_key "],
101+ aws_session_token=credentials["aws_session_token "],
94102 )
95103
96104 s3 = session.resource("s3")
@@ -123,11 +131,11 @@ def get_aws_content_credentials(
123131 return _decode_access_token (access_token )
124132
125133
126- def _decode_access_token (access_token : str ) -> Dict [ str , str ] :
134+ def _decode_access_token (access_token : str ) -> Credentials :
127135 """
128136 Decode and deserialize an access token containing AWS credentials.
129137
130- According to RFC 8693, the access token must be a base64 encoded JSON object
138+ According to RFC 8693, the access token must be a base64- encoded JSON object
131139 containing the AWS credentials. This function will decode and deserialize the
132140 access token and return the AWS credentials.
133141
@@ -138,12 +146,17 @@ def _decode_access_token(access_token: str) -> Dict[str, str]:
138146
139147 Returns
140148 -------
141- Dict[str, str]
149+ Credentials
142150 Dictionary containing AWS credentials with keys:
143151 access_key_id, secret_access_key, session_token, and expiration
144152 """
145153 decoded_bytes = base64 .b64decode (access_token )
146154 decoded_str = decoded_bytes .decode ("utf-8" )
147155 aws_credentials = json .loads (decoded_str )
148156
149- return aws_credentials
157+ return Credentials (
158+ aws_access_key_id = aws_credentials ["accessKeyId" ],
159+ aws_secret_access_key = aws_credentials ["secretAccessKey" ],
160+ aws_session_token = aws_credentials ["sessionToken" ],
161+ expiration = datetime .strptime (aws_credentials ["expiration" ], "%Y-%m-%dT%H:%M:%SZ" ),
162+ )
0 commit comments