Skip to content

Commit 6049ee6

Browse files
added standard token type enum; added aws_creds helper
1 parent e37c031 commit 6049ee6

File tree

5 files changed

+162
-10
lines changed

5 files changed

+162
-10
lines changed

src/posit/connect/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .context import Context, ContextManager, requires
1212
from .groups import Groups
1313
from .metrics.metrics import Metrics
14-
from .oauth.oauth import API_KEY_TOKEN_TYPE, OAuth
14+
from .oauth.oauth import OAuth, OAuthTokenType
1515
from .resources import _PaginatedResourceSequence, _ResourceSequence
1616
from .sessions import Session
1717
from .system import System
@@ -256,7 +256,7 @@ def user_profile():
256256
raise ValueError("token must be set to non-empty string.")
257257

258258
visitor_credentials = self.oauth.get_credentials(
259-
token, requested_token_type=API_KEY_TOKEN_TYPE
259+
token, requested_token_type=OAuthTokenType.API_KEY
260260
)
261261

262262
visitor_api_key = visitor_credentials.get("access_token", "")

src/posit/connect/external/aws.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import json
5+
6+
from typing_extensions import TYPE_CHECKING, Dict
7+
8+
from ..oauth.oauth import OAuthTokenType
9+
10+
if TYPE_CHECKING:
11+
from ..client import Client
12+
13+
14+
def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, str]:
15+
"""
16+
Get AWS credentials using OAuth token exchange.
17+
18+
According to RFC 8693, the access token must be a base64 encoded JSON object
19+
containing the AWS credentials. This function will decode and deserialize the
20+
access token and return the AWS credentials.
21+
22+
Parameters
23+
----------
24+
client : Client
25+
The client to use for making requests
26+
user_session_token : str
27+
The user session token to exchange
28+
29+
Returns
30+
-------
31+
Dict[str, str]
32+
Dictionary containing AWS credentials with keys:
33+
access_key_id, secret_access_key, session_token, and expiration
34+
"""
35+
# Get credentials using OAuth
36+
credentials = client.oauth.get_credentials(
37+
user_session_token=user_session_token,
38+
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
39+
)
40+
41+
# Decode base64 access token
42+
access_token = credentials.get("access_token")
43+
if not access_token:
44+
raise ValueError("No access token found in credentials")
45+
decoded_bytes = base64.b64decode(access_token)
46+
decoded_str = decoded_bytes.decode("utf-8")
47+
aws_credentials = json.loads(decoded_str)
48+
49+
return aws_credentials

src/posit/connect/oauth/oauth.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
from enum import Enum
45

56
from typing_extensions import TYPE_CHECKING, Optional, TypedDict
67

@@ -12,9 +13,14 @@
1213
from ..context import Context
1314

1415
GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
15-
USER_SESSION_TOKEN_TYPE = "urn:posit:connect:user-session-token"
16-
CONTENT_SESSION_TOKEN_TYPE = "urn:posit:connect:content-session-token"
17-
API_KEY_TOKEN_TYPE = "urn:posit:connect:api-key"
16+
17+
18+
class OAuthTokenType(str, Enum):
19+
ACCESS_TOKEN = "urn:ietf:params:oauth:token-type:access_token"
20+
AWS_CREDENTIALS = "urn:ietf:params:aws:token-type:credentials"
21+
API_KEY = "urn:posit:connect:api-key"
22+
CONTENT_SESSION_TOKEN = "urn:posit:connect:content-session-token"
23+
USER_SESSION_TOKEN = "urn:posit:connect:user-session-token"
1824

1925

2026
def _get_content_session_token() -> str:
@@ -53,13 +59,15 @@ def sessions(self):
5359
return Sessions(self._ctx)
5460

5561
def get_credentials(
56-
self, user_session_token: Optional[str] = None, requested_token_type: Optional[str] = None
62+
self,
63+
user_session_token: Optional[str] = None,
64+
requested_token_type: Optional[str | OAuthTokenType] = None,
5765
) -> Credentials:
5866
"""Perform an oauth credential exchange with a user-session-token."""
5967
# craft a credential exchange request
6068
data = {}
6169
data["grant_type"] = GRANT_TYPE
62-
data["subject_token_type"] = USER_SESSION_TOKEN_TYPE
70+
data["subject_token_type"] = OAuthTokenType.USER_SESSION_TOKEN
6371
if user_session_token:
6472
data["subject_token"] = user_session_token
6573
if requested_token_type:
@@ -73,7 +81,7 @@ def get_content_credentials(self, content_session_token: Optional[str] = None) -
7381
# craft a credential exchange request
7482
data = {}
7583
data["grant_type"] = GRANT_TYPE
76-
data["subject_token_type"] = CONTENT_SESSION_TOKEN_TYPE
84+
data["subject_token_type"] = OAuthTokenType.CONTENT_SESSION_TOKEN
7785
data["subject_token"] = content_session_token or _get_content_session_token()
7886

7987
response = self._ctx.client.post(self._path, data=data)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
import responses
3+
4+
from posit.connect import Client
5+
from posit.connect.external.aws import get_aws_credentials
6+
7+
aws_creds = {
8+
"accessKeyId": "abc123",
9+
"secretAccessKey": "def456",
10+
"sessionToken": "ghi789",
11+
"expiration": "2025-01-01T00:00:00Z",
12+
}
13+
14+
encoded_aws_creds = "eyJhY2Nlc3NLZXlJZCI6ICJhYmMxMjMiLCAic2VjcmV0QWNjZXNzS2V5IjogImRlZjQ1NiIsICJzZXNzaW9uVG9rZW4iOiAiZ2hpNzg5IiwgImV4cGlyYXRpb24iOiAiMjAyNS0wMS0wMVQwMDowMDowMFoifQ=="
15+
16+
17+
class TestAWS:
18+
@responses.activate
19+
def test_get_aws_credentials(self):
20+
responses.post(
21+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
22+
match=[
23+
responses.matchers.urlencoded_params_matcher(
24+
{
25+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
26+
"subject_token_type": "urn:posit:connect:user-session-token",
27+
"subject_token": "cit",
28+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
29+
}
30+
)
31+
],
32+
json={
33+
"access_token": encoded_aws_creds,
34+
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
35+
"token_type": "aws_credentials",
36+
},
37+
)
38+
39+
c = Client(api_key="12345", url="https://connect.example/")
40+
c._ctx.version = None
41+
response = get_aws_credentials(c, "cit")
42+
43+
assert response == aws_creds
44+
45+
@responses.activate
46+
def test_get_aws_credentials_no_token(self):
47+
responses.post(
48+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
49+
match=[
50+
responses.matchers.urlencoded_params_matcher(
51+
{
52+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
53+
"subject_token_type": "urn:posit:connect:user-session-token",
54+
"subject_token": "cit",
55+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
56+
}
57+
)
58+
],
59+
json={},
60+
)
61+
62+
c = Client(api_key="12345", url="https://connect.example/")
63+
c._ctx.version = None
64+
65+
with pytest.raises(ValueError) as e:
66+
get_aws_credentials(c, "cit")
67+
68+
assert e.match("No access token found in credentials")

tests/posit/connect/oauth/test_oauth.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import responses
55

66
from posit.connect import Client
7-
from posit.connect.oauth.oauth import API_KEY_TOKEN_TYPE, _get_content_session_token
7+
from posit.connect.oauth.oauth import OAuthTokenType, _get_content_session_token
88

99

1010
class TestOAuthIntegrations:
@@ -62,11 +62,38 @@ def test_get_credentials_api_key(self):
6262
)
6363
c = Client(api_key="12345", url="https://connect.example/")
6464
c._ctx.version = None
65-
creds = c.oauth.get_credentials("cit", API_KEY_TOKEN_TYPE)
65+
creds = c.oauth.get_credentials("cit", OAuthTokenType.API_KEY)
6666
assert creds.get("access_token") == "viewer-api-key"
6767
assert creds.get("issued_token_type") == "urn:posit:connect:api-key"
6868
assert creds.get("token_type") == "Key"
6969

70+
@responses.activate
71+
def test_get_credentials_aws(self):
72+
responses.post(
73+
"https://connect.example/__api__/v1/oauth/integrations/credentials",
74+
match=[
75+
responses.matchers.urlencoded_params_matcher(
76+
{
77+
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
78+
"subject_token_type": "urn:posit:connect:user-session-token",
79+
"subject_token": "cit",
80+
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
81+
},
82+
),
83+
],
84+
json={
85+
"access_token": "encoded-aws-creds",
86+
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
87+
"token_type": "aws_credentials",
88+
},
89+
)
90+
c = Client(api_key="12345", url="https://connect.example/")
91+
c._ctx.version = None
92+
creds = c.oauth.get_credentials("cit", OAuthTokenType.AWS_CREDENTIALS)
93+
assert creds.get("access_token") == "encoded-aws-creds"
94+
assert creds.get("issued_token_type") == "urn:ietf:params:aws:token-type:credentials"
95+
assert creds.get("token_type") == "aws_credentials"
96+
7097
@responses.activate
7198
def test_get_content_credentials(self):
7299
responses.post(

0 commit comments

Comments
 (0)