Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/posit/connect/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import os

from typing_extensions import Any


Expand Down Expand Up @@ -28,3 +30,13 @@ def update_dict_values(obj: dict[str, Any], /, **kwargs: Any) -> None:

# Use the `dict` class to explicity update the object in-place
dict.update(obj, **kwargs)


def is_local() -> bool:
"""Returns true if called from a piece of content running on a Connect server.

The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`.
We can use this environment variable to determine if the content is running locally
or on a Connect server.
"""
return os.getenv("RSTUDIO_PRODUCT") != "CONNECT"
41 changes: 34 additions & 7 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

import os

from requests import Response, Session
from typing_extensions import TYPE_CHECKING, overload

from . import hooks, me
from ._utils import is_local
from .auth import Auth
from .config import Config
from .content import Content
Expand Down Expand Up @@ -174,16 +177,27 @@ def __init__(self, *args, **kwargs) -> None:
self._ctx = Context(self)

@requires("2025.01.0-dev")
def with_user_session_token(self, token: str) -> Client:
def with_user_session_token(
self, token: str, fallback_api_key_env_var: str = "CONNECT_API_KEY"
) -> Client:
"""Create a new Client scoped to the user specified in the user session token.

Create a new Client instance from a user session token exchange for an api key scoped to the
user specified in the token.
user specified in the token (the user viewing your app). If running your application locally,
you will not have a user session token. In that case, this method will look for an API key in
the environment variable specified in `fallback_api_key_env_var`. If that is not set, the API
key of the original client will be used.

Environment Variables
---------------------
CONNECT_API_KEY - The API key credential for client authentication.

Parameters
----------
token : str
The user session token.
fallback_api_key_env_var: str
Environment variable with a fallback API key for local development.

Returns
-------
Expand All @@ -195,14 +209,27 @@ def with_user_session_token(self, token: str) -> Client:
>>> from posit.connect import Client
>>> client = Client().with_user_session_token("my-user-session-token")
"""
viewer_credentials = self.oauth.get_credentials(
if is_local():
# if user session token is not available when running locally,
# default to using API set in environment variable
return Client(
url=self.cfg.url,
api_key=os.getenv(fallback_api_key_env_var, self.cfg.api_key),
)

if token is None or token == "":
# if deployed to Connect, token must be set
raise ValueError("token must be set to non-empty string.")

visitor_credentials = self.oauth.get_credentials(
token, requested_token_type=API_KEY_TOKEN_TYPE
)
viewer_api_key = viewer_credentials.get("access_token")
if viewer_api_key is None:
raise ValueError("Unable to retrieve viewer api key.")

return Client(url=self.cfg.url, api_key=viewer_api_key)
visitor_api_key = visitor_credentials.get("access_token", "")
if visitor_api_key == "":
raise ValueError("Unable to retrieve visitor API key.")

return Client(url=self.cfg.url, api_key=visitor_api_key)

@property
def content(self) -> Content:
Expand Down
2 changes: 1 addition & 1 deletion src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import requests
from typing_extensions import Callable, Dict, Optional

from .._utils import is_local
from ..client import Client
from ..oauth import Credentials
from .external import is_local

POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration"
POSIT_LOCAL_CLIENT_CREDENTIALS_AUTH_TYPE = "posit-local-client-credentials"
Expand Down
11 changes: 0 additions & 11 deletions src/posit/connect/external/external.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/posit/connect/external/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from typing_extensions import Optional

from .._utils import is_local
from ..client import Client
from .external import is_local


class PositAuthenticator:
Expand Down
41 changes: 36 additions & 5 deletions tests/posit/connect/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_init(
MockSession.assert_called_once()

@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_with_user_session_token(self):
api_key = "12345"
url = "https://connect.example.com"
Expand All @@ -110,13 +111,14 @@ def test_with_user_session_token(self):
},
)

viewer_client = client.with_user_session_token("cit")
visitor_client = client.with_user_session_token("cit")

assert viewer_client.cfg.url == "https://connect.example.com/__api__"
assert viewer_client.cfg.api_key == "api-key"
assert visitor_client.cfg.url == "https://connect.example.com/__api__"
assert visitor_client.cfg.api_key == "api-key"

@responses.activate
def test_with_user_session_token_bad_exchange(self):
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_with_user_session_token_bad_exchange_response_body(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
Expand All @@ -137,8 +139,37 @@ def test_with_user_session_token_bad_exchange(self):
json={},
)

with pytest.raises(ValueError):
with pytest.raises(ValueError) as err:
client.with_user_session_token("cit")
assert str(err.value) == "Unable to retrieve visitor API key."

@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_with_user_session_token_bad_token_when_deployed(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
client._ctx.version = None

with pytest.raises(ValueError) as err:
client.with_user_session_token("")
assert str(err.value) == "token must be set to non-empty string."

@responses.activate
@patch.dict("os.environ", {"CONNECT_API_KEY": "ABC123", "CUSTOM_KEY": "DEF456"})
def test_with_user_session_token_local_env_var_override(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
client._ctx.version = None

visitor_client = client.with_user_session_token("")
assert visitor_client.cfg.url == client.cfg.url
assert visitor_client.cfg.api_key == "ABC123"

visitor_client = client.with_user_session_token("", fallback_api_key_env_var="CUSTOM_KEY")
assert visitor_client.cfg.url == client.cfg.url
assert visitor_client.cfg.api_key == "DEF456"

def test__del__(
self,
Expand Down
Loading