Skip to content

Commit 163a9d4

Browse files
add Credentials typeddict; simplify aws credential naming
1 parent 40d022f commit 163a9d4

File tree

2 files changed

+49
-33
lines changed

2 files changed

+49
-33
lines changed

src/posit/connect/external/aws.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,30 @@
22

33
import base64
44
import json
5+
from datetime import datetime
56

6-
from typing_extensions import TYPE_CHECKING, Dict, Optional
7+
from typing_extensions import TYPE_CHECKING, Optional, TypedDict
78

89
from ..oauth.oauth import OAuthTokenType
910

1011
if TYPE_CHECKING:
1112
from ..client import Client
1213

1314

14-
def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, str]:
15+
class Credentials(TypedDict):
16+
aws_access_key_id: str
17+
aws_secret_access_key: str
18+
aws_session_token: str
19+
expiration: datetime
20+
21+
22+
def get_credentials(client: Client, user_session_token: str) -> Credentials:
1523
"""
1624
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
1725
18-
According to RFC 8693, the access token must be a base64 encoded JSON object
19-
containing the AWS credentials. This function will decode and deserialize the
20-
access token and return the AWS credentials.
26+
According to RFC 8693, the access token must be a base64-encoded JSON object
27+
containing the AWS credentials. This function will return the decoded and
28+
deserialized AWS credentials.
2129
2230
Examples
2331
--------
@@ -32,9 +40,9 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
3240
credentials = get_aws_credentials(client, user_session_token)
3341
aws_session_expiration = credentials["expiration"]
3442
aws_session = boto3.Session(
35-
aws_access_key_id=credentials["accessKeyId"],
36-
aws_secret_access_key=credentials["secretAccessKey"],
37-
aws_session_token=credentials["sessionToken"],
43+
aws_access_key_id=credentials["aws_access_key_id"],
44+
aws_secret_access_key=credentials["aws_secret_access_key"],
45+
aws_session_token=credentials["aws_session_token"],
3846
)
3947
4048
s3 = aws_session.resource("s3")
@@ -67,15 +75,15 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
6775
return _decode_access_token(access_token)
6876

6977

70-
def get_aws_content_credentials(
78+
def get_content_credentials(
7179
client: Client, content_session_token: Optional[str] = None
72-
) -> Dict[str, str]:
80+
) -> Credentials:
7381
"""
7482
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
7583
76-
According to RFC 8693, the access token must be a base64 encoded JSON object
77-
containing the AWS credentials. This function will decode and deserialize the
78-
access token and return the AWS credentials.
84+
According to RFC 8693, the access token must be a base64-encoded JSON object
85+
containing the AWS credentials. This function will return the decoded and
86+
deserialized AWS credentials.
7987
8088
Examples
8189
--------
@@ -87,10 +95,10 @@ def get_aws_content_credentials(
8795
client = Client()
8896
credentials = get_aws_content_credentials(client)
8997
session_expiration = credentials["expiration"]
90-
session = boto3.Session(
91-
aws_access_key_id=credentials["accessKeyId"],
92-
aws_secret_access_key=credentials["secretAccessKey"],
93-
aws_session_token=credentials["sessionToken"],
98+
aws_session = boto3.Session(
99+
aws_access_key_id=credentials["aws_access_key_id"],
100+
aws_secret_access_key=credentials["aws_secret_access_key"],
101+
aws_session_token=credentials["aws_session_token"],
94102
)
95103
96104
s3 = session.resource("s3")
@@ -123,11 +131,11 @@ def get_aws_content_credentials(
123131
return _decode_access_token(access_token)
124132

125133

126-
def _decode_access_token(access_token: str) -> Dict[str, str]:
134+
def _decode_access_token(access_token: str) -> Credentials:
127135
"""
128136
Decode and deserialize an access token containing AWS credentials.
129137
130-
According to RFC 8693, the access token must be a base64 encoded JSON object
138+
According to RFC 8693, the access token must be a base64-encoded JSON object
131139
containing the AWS credentials. This function will decode and deserialize the
132140
access token and return the AWS credentials.
133141
@@ -138,12 +146,17 @@ def _decode_access_token(access_token: str) -> Dict[str, str]:
138146
139147
Returns
140148
-------
141-
Dict[str, str]
149+
Credentials
142150
Dictionary containing AWS credentials with keys:
143151
access_key_id, secret_access_key, session_token, and expiration
144152
"""
145153
decoded_bytes = base64.b64decode(access_token)
146154
decoded_str = decoded_bytes.decode("utf-8")
147155
aws_credentials = json.loads(decoded_str)
148156

149-
return aws_credentials
157+
return Credentials(
158+
aws_access_key_id=aws_credentials["accessKeyId"],
159+
aws_secret_access_key=aws_credentials["secretAccessKey"],
160+
aws_session_token=aws_credentials["sessionToken"],
161+
expiration=datetime.strptime(aws_credentials["expiration"], "%Y-%m-%dT%H:%M:%SZ"),
162+
)

tests/posit/connect/external/test_aws.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
from datetime import datetime
2+
13
import pytest
24
import responses
35

46
from posit.connect import Client
57
from posit.connect.external.aws import (
8+
Credentials,
69
_decode_access_token,
7-
get_aws_content_credentials,
8-
get_aws_credentials,
10+
get_content_credentials,
11+
get_credentials,
912
)
1013

11-
aws_creds = {
12-
"accessKeyId": "abc123",
13-
"secretAccessKey": "def456",
14-
"sessionToken": "ghi789",
15-
"expiration": "2025-01-01T00:00:00Z",
16-
}
14+
aws_creds = Credentials(
15+
aws_access_key_id="abc123",
16+
aws_secret_access_key="def456",
17+
aws_session_token="ghi789",
18+
expiration=datetime(2025, 1, 1, 0, 0, 0, 0),
19+
)
1720

1821
encoded_aws_creds = "eyJhY2Nlc3NLZXlJZCI6ICJhYmMxMjMiLCAic2VjcmV0QWNjZXNzS2V5IjogImRlZjQ1NiIsICJzZXNzaW9uVG9rZW4iOiAiZ2hpNzg5IiwgImV4cGlyYXRpb24iOiAiMjAyNS0wMS0wMVQwMDowMDowMFoifQ=="
1922

@@ -42,7 +45,7 @@ def test_get_aws_credentials(self):
4245

4346
c = Client(api_key="12345", url="https://connect.example/")
4447
c._ctx.version = None
45-
response = get_aws_credentials(c, "cit")
48+
response = get_credentials(c, "cit")
4649

4750
assert response == aws_creds
4851

@@ -67,7 +70,7 @@ def test_get_aws_credentials_no_token(self):
6770
c._ctx.version = None
6871

6972
with pytest.raises(ValueError) as e:
70-
get_aws_credentials(c, "cit")
73+
get_credentials(c, "cit")
7174

7275
assert e.match("No access token found in credentials")
7376

@@ -94,7 +97,7 @@ def test_get_aws_content_credentials(self):
9497

9598
c = Client(api_key="12345", url="https://connect.example/")
9699
c._ctx.version = None
97-
response = get_aws_content_credentials(c, "cit")
100+
response = get_content_credentials(c, "cit")
98101

99102
assert response == aws_creds
100103

@@ -119,7 +122,7 @@ def test_get_aws_content_credentials_no_token(self):
119122
c._ctx.version = None
120123

121124
with pytest.raises(ValueError) as e:
122-
get_aws_content_credentials(c, "cit")
125+
get_content_credentials(c, "cit")
123126

124127
assert e.match("No access token found in credentials")
125128

0 commit comments

Comments
 (0)