Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/posit/connect/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import os

from typing_extensions import Any


Expand Down Expand Up @@ -28,3 +30,13 @@ def update_dict_values(obj: dict[str, Any], /, **kwargs: Any) -> None:

# Use the `dict` class to explicity update the object in-place
dict.update(obj, **kwargs)


def is_local() -> bool:
"""Returns true if called from a piece of content running on a Connect server.

The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`.
We can use this environment variable to determine if the content is running locally
or on a Connect server.
"""
return os.getenv("RSTUDIO_PRODUCT") != "CONNECT"
18 changes: 12 additions & 6 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def with_user_session_token(self, token: str) -> Client:
"""Create a new Client scoped to the user specified in the user session token.

Create a new Client instance from a user session token exchange for an api key scoped to the
user specified in the token.
user specified in the token (the user viewing your app). If running your application locally,
a user session token will not exist, which will cause this method to result in an error needing
to be handled in your application.

Parameters
----------
Expand All @@ -195,14 +197,18 @@ def with_user_session_token(self, token: str) -> Client:
>>> from posit.connect import Client
>>> client = Client().with_user_session_token("my-user-session-token")
"""
viewer_credentials = self.oauth.get_credentials(
if token is None or token == "":
raise ValueError("token must be set to non-empty string.")

visitor_credentials = self.oauth.get_credentials(
token, requested_token_type=API_KEY_TOKEN_TYPE
)
viewer_api_key = viewer_credentials.get("access_token")
if viewer_api_key is None:
raise ValueError("Unable to retrieve viewer api key.")

return Client(url=self.cfg.url, api_key=viewer_api_key)
visitor_api_key = visitor_credentials.get("access_token", "")
if visitor_api_key == "":
raise ValueError("Unable to retrieve token.")

return Client(url=self.cfg.url, api_key=visitor_api_key)

@property
def content(self) -> Content:
Expand Down
2 changes: 1 addition & 1 deletion src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import requests
from typing_extensions import Callable, Dict, Optional

from .._utils import is_local
from ..client import Client
from ..oauth import Credentials
from .external import is_local

POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration"
POSIT_LOCAL_CLIENT_CREDENTIALS_AUTH_TYPE = "posit-local-client-credentials"
Expand Down
11 changes: 0 additions & 11 deletions src/posit/connect/external/external.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/posit/connect/external/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from typing_extensions import Optional

from .._utils import is_local
from ..client import Client
from .external import is_local


class PositAuthenticator:
Expand Down
38 changes: 33 additions & 5 deletions tests/posit/connect/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_init(
MockSession.assert_called_once()

@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_with_user_session_token(self):
api_key = "12345"
url = "https://connect.example.com"
Expand All @@ -110,13 +111,14 @@ def test_with_user_session_token(self):
},
)

viewer_client = client.with_user_session_token("cit")
visitor_client = client.with_user_session_token("cit")

assert viewer_client.cfg.url == "https://connect.example.com/__api__"
assert viewer_client.cfg.api_key == "api-key"
assert visitor_client.cfg.url == "https://connect.example.com/__api__"
assert visitor_client.cfg.api_key == "api-key"

@responses.activate
def test_with_user_session_token_bad_exchange(self):
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_with_user_session_token_bad_exchange_response_body(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
Expand All @@ -137,8 +139,34 @@ def test_with_user_session_token_bad_exchange(self):
json={},
)

with pytest.raises(ValueError):
with pytest.raises(ValueError) as err:
client.with_user_session_token("cit")
assert str(err.value) == "Unable to retrieve token."

@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_with_user_session_token_bad_token_deployed(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
client._ctx.version = None

with pytest.raises(ValueError) as err:
client.with_user_session_token("")
assert str(err.value) == "token must be set to non-empty string."

def test_with_user_session_token_bad_token_local(self):
api_key = "12345"
url = "https://connect.example.com"
client = Client(api_key=api_key, url=url)
client._ctx.version = None

with pytest.raises(ValueError) as e:
client.with_user_session_token("")
assert str(e.value) == "token must be set to non-empty string."

with pytest.raises(ValueError) as e:
client.with_user_session_token(None) # type: ignore
assert str(e.value) == "token must be set to non-empty string."

def test__del__(
self,
Expand Down
Loading