Skip to content

Commit dfc4781

Browse files
committed
Merge remote-tracking branch 'origin/main' into staging/0.9.0
2 parents 37b58ab + 23d5778 commit dfc4781

File tree

9 files changed

+232
-42
lines changed

9 files changed

+232
-42
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,6 @@ cython_debug/
170170

171171
/.luarc.json
172172
_dev/
173+
174+
# license files should not be commited to this repository
175+
*.lic

integration/.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
*.lic
21
logs
32
reports

src/posit/connect/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,9 @@ def oauth(self) -> OAuth:
349349
@property
350350
@requires(version="2024.11.0")
351351
def packages(self) -> Packages:
352-
return _PaginatedResourceSequence(self._ctx, "v1/packages", uid="name")
352+
return _PaginatedResourceSequence(
353+
self._ctx, "v1/packages", uid="name", page_size=1_000_000
354+
)
353355

354356
@property
355357
def system(self) -> System:

src/posit/connect/external/aws.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,52 @@
22

33
import base64
44
import json
5+
from datetime import datetime
56

6-
from typing_extensions import TYPE_CHECKING, Dict
7+
from typing_extensions import TYPE_CHECKING, Optional, TypedDict
78

89
from ..oauth.oauth import OAuthTokenType
910

1011
if TYPE_CHECKING:
1112
from ..client import Client
1213

1314

14-
def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, str]:
15+
class Credentials(TypedDict):
16+
aws_access_key_id: str
17+
aws_secret_access_key: str
18+
aws_session_token: str
19+
expiration: datetime
20+
21+
22+
def get_credentials(client: Client, user_session_token: str) -> Credentials:
1523
"""
16-
Get AWS credentials using OAuth token exchange.
24+
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.
1725
18-
According to RFC 8693, the access token must be a base64 encoded JSON object
19-
containing the AWS credentials. This function will decode and deserialize the
20-
access token and return the AWS credentials.
26+
According to RFC 8693, the access token must be a base64-encoded JSON object
27+
containing the AWS credentials. This function will return the decoded and
28+
deserialized AWS credentials.
29+
30+
Examples
31+
--------
32+
```python
33+
from posit.connect import Client
34+
from posit.connect.external.aws import get_aws_credentials
35+
import boto3
36+
from shiny.express import session
37+
38+
client = Client()
39+
session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token")
40+
credentials = get_aws_credentials(client, user_session_token)
41+
aws_session_expiration = credentials["expiration"]
42+
aws_session = boto3.Session(
43+
aws_access_key_id=credentials["aws_access_key_id"],
44+
aws_secret_access_key=credentials["aws_secret_access_key"],
45+
aws_session_token=credentials["aws_session_token"],
46+
)
47+
48+
s3 = aws_session.resource("s3")
49+
bucket = s3.Bucket("your-bucket-name")
50+
```
2151
2252
Parameters
2353
----------
@@ -42,8 +72,91 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
4272
access_token = credentials.get("access_token")
4373
if not access_token:
4474
raise ValueError("No access token found in credentials")
75+
return _decode_access_token(access_token)
76+
77+
78+
def get_content_credentials(
79+
client: Client, content_session_token: Optional[str] = None
80+
) -> Credentials:
81+
"""
82+
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
83+
84+
According to RFC 8693, the access token must be a base64-encoded JSON object
85+
containing the AWS credentials. This function will return the decoded and
86+
deserialized AWS credentials.
87+
88+
Examples
89+
--------
90+
```python
91+
from posit.connect import Client
92+
from posit.connect.external.aws import get_aws_content_credentials
93+
import boto3
94+
95+
client = Client()
96+
credentials = get_aws_content_credentials(client)
97+
session_expiration = credentials["expiration"]
98+
aws_session = boto3.Session(
99+
aws_access_key_id=credentials["aws_access_key_id"],
100+
aws_secret_access_key=credentials["aws_secret_access_key"],
101+
aws_session_token=credentials["aws_session_token"],
102+
)
103+
104+
s3 = session.resource("s3")
105+
bucket = s3.Bucket("your-bucket-name")
106+
```
107+
108+
Parameters
109+
----------
110+
client : Client
111+
The client to use for making requests
112+
content_session_token : str
113+
The content session token to exchange
114+
115+
Returns
116+
-------
117+
Dict[str, str]
118+
Dictionary containing AWS credentials with keys:
119+
access_key_id, secret_access_key, session_token, and expiration
120+
"""
121+
# Get credentials using OAuth
122+
credentials = client.oauth.get_content_credentials(
123+
content_session_token=content_session_token,
124+
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
125+
)
126+
127+
# Decode base64 access token
128+
access_token = credentials.get("access_token")
129+
if not access_token:
130+
raise ValueError("No access token found in credentials")
131+
return _decode_access_token(access_token)
132+
133+
134+
def _decode_access_token(access_token: str) -> Credentials:
135+
"""
136+
Decode and deserialize an access token containing AWS credentials.
137+
138+
According to RFC 8693, the access token must be a base64-encoded JSON object
139+
containing the AWS credentials. This function will decode and deserialize the
140+
access token and return the AWS credentials.
141+
142+
Parameters
143+
----------
144+
access_token : str
145+
The access token to decode
146+
147+
Returns
148+
-------
149+
Credentials
150+
Dictionary containing AWS credentials with keys:
151+
access_key_id, secret_access_key, session_token, and expiration
152+
"""
45153
decoded_bytes = base64.b64decode(access_token)
46154
decoded_str = decoded_bytes.decode("utf-8")
47155
aws_credentials = json.loads(decoded_str)
48156

49-
return aws_credentials
157+
return Credentials(
158+
aws_access_key_id=aws_credentials["accessKeyId"],
159+
aws_secret_access_key=aws_credentials["secretAccessKey"],
160+
aws_session_token=aws_credentials["sessionToken"],
161+
expiration=datetime.strptime(aws_credentials["expiration"], "%Y-%m-%dT%H:%M:%SZ"),
162+
)

src/posit/connect/oauth/oauth.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,19 @@ def get_credentials(
7676
response = self._ctx.client.post(self._path, data=data)
7777
return Credentials(**response.json())
7878

79-
def get_content_credentials(self, content_session_token: Optional[str] = None) -> Credentials:
79+
def get_content_credentials(
80+
self,
81+
content_session_token: Optional[str] = None,
82+
requested_token_type: Optional[str | OAuthTokenType] = None,
83+
) -> Credentials:
8084
"""Perform an oauth credential exchange with a content-session-token."""
8185
# craft a credential exchange request
8286
data = {}
8387
data["grant_type"] = GRANT_TYPE
8488
data["subject_token_type"] = OAuthTokenType.CONTENT_SESSION_TOKEN
8589
data["subject_token"] = content_session_token or _get_content_session_token()
90+
if requested_token_type:
91+
data["requested_token_type"] = requested_token_type
8692

8793
response = self._ctx.client.post(self._path, data=data)
8894
return Credentials(**response.json())

src/posit/connect/paginator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,14 @@ class Paginator:
4343
"""
4444

4545
def __init__(
46-
self,
47-
ctx: Context,
48-
path: str,
49-
params: dict | None = None,
46+
self, ctx: Context, path: str, params: dict | None = None, page_size: int | None = None
5047
) -> None:
5148
if params is None:
5249
params = {}
5350
self._ctx = ctx
5451
self._path = path
5552
self._params = params
53+
self._page_size = page_size or _MAX_PAGE_SIZE
5654

5755
def fetch_results(self) -> List[dict]:
5856
"""
@@ -109,7 +107,7 @@ def fetch_page(self, page_number: int) -> Page:
109107
params = {
110108
**self._params,
111109
"page_number": page_number,
112-
"page_size": _MAX_PAGE_SIZE,
110+
"page_size": self._page_size,
113111
}
114112
response = self._ctx.client.get(self._path, params=params)
115113
return Page(**response.json())

src/posit/connect/resources.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,12 @@ def find_by(self, **conditions) -> Any | None:
177177

178178

179179
class _PaginatedResourceSequence(_ResourceSequence):
180+
def __init__(self, ctx, path: str, *, uid: str = "guid", page_size: int | None = None):
181+
super().__init__(ctx, path, uid=uid)
182+
self._page_size = page_size
183+
180184
def fetch(self, **conditions):
181-
paginator = Paginator(self._ctx, self._path, dict(**conditions))
185+
paginator = Paginator(self._ctx, self._path, dict(**conditions), page_size=self._page_size)
182186
for page in paginator.fetch_pages():
183187
resources = []
184188
results = page.results

tests/posit/connect/external/test_aws.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1+
from datetime import datetime
2+
13
import pytest
24
import responses
35

46
from posit.connect import Client
5-
from posit.connect.external.aws import get_aws_credentials
7+
from posit.connect.external.aws import (
8+
Credentials,
9+
_decode_access_token,
10+
get_content_credentials,
11+
get_credentials,
12+
)
613

7-
aws_creds = {
8-
"accessKeyId": "abc123",
9-
"secretAccessKey": "def456",
10-
"sessionToken": "ghi789",
11-
"expiration": "2025-01-01T00:00:00Z",
12-
}
14+
aws_creds = Credentials(
15+
aws_access_key_id="abc123",
16+
aws_secret_access_key="def456",
17+
aws_session_token="ghi789",
18+
expiration=datetime(2025, 1, 1, 0, 0, 0, 0),
19+
)
1320

1421
encoded_aws_creds = "eyJhY2Nlc3NLZXlJZCI6ICJhYmMxMjMiLCAic2VjcmV0QWNjZXNzS2V5IjogImRlZjQ1NiIsICJzZXNzaW9uVG9rZW4iOiAiZ2hpNzg5IiwgImV4cGlyYXRpb24iOiAiMjAyNS0wMS0wMVQwMDowMDowMFoifQ=="
1522

@@ -38,7 +45,7 @@ def test_get_aws_credentials(self):
3845

3946
c = Client(api_key="12345", url="https://connect.example/")
4047
c._ctx.version = None
41-
response = get_aws_credentials(c, "cit")
48+
response = get_credentials(c, "cit")
4249

4350
assert response == aws_creds
4451

@@ -63,6 +70,62 @@ def test_get_aws_credentials_no_token(self):
6370
c._ctx.version = None
6471

6572
with pytest.raises(ValueError) as e:
66-
get_aws_credentials(c, "cit")
73+
get_credentials(c, "cit")
74+
75+
assert e.match("No access token found in credentials")
76+
77+
@responses.activate
78+
def test_get_aws_content_credentials(self):
79+
responses.post(
80+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
81+
match=[
82+
responses.matchers.urlencoded_params_matcher(
83+
{
84+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
85+
"subject_token_type": "urn:posit:connect:content-session-token",
86+
"subject_token": "cit",
87+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
88+
}
89+
)
90+
],
91+
json={
92+
"access_token": encoded_aws_creds,
93+
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
94+
"token_type": "aws_credentials",
95+
},
96+
)
97+
98+
c = Client(api_key="12345", url="https://connect.example/")
99+
c._ctx.version = None
100+
response = get_content_credentials(c, "cit")
101+
102+
assert response == aws_creds
103+
104+
@responses.activate
105+
def test_get_aws_content_credentials_no_token(self):
106+
responses.post(
107+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
108+
match=[
109+
responses.matchers.urlencoded_params_matcher(
110+
{
111+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
112+
"subject_token_type": "urn:posit:connect:content-session-token",
113+
"subject_token": "cit",
114+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
115+
}
116+
)
117+
],
118+
json={},
119+
)
120+
121+
c = Client(api_key="12345", url="https://connect.example/")
122+
c._ctx.version = None
123+
124+
with pytest.raises(ValueError) as e:
125+
get_content_credentials(c, "cit")
67126

68127
assert e.match("No access token found in credentials")
128+
129+
def test_decode_access_token(self):
130+
decoded_creds = _decode_access_token(encoded_aws_creds)
131+
assert decoded_creds == aws_creds

0 commit comments

Comments
 (0)