diff --git a/src/posit/connect/_utils.py b/src/posit/connect/_utils.py index c35dabd9..d9c1b083 100644 --- a/src/posit/connect/_utils.py +++ b/src/posit/connect/_utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os + from typing_extensions import Any @@ -28,3 +30,13 @@ def update_dict_values(obj: dict[str, Any], /, **kwargs: Any) -> None: # Use the `dict` class to explicity update the object in-place dict.update(obj, **kwargs) + + +def is_local() -> bool: + """Returns true if called from a piece of content running on a Connect server. + + The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`. + We can use this environment variable to determine if the content is running locally + or on a Connect server. + """ + return os.getenv("RSTUDIO_PRODUCT") != "CONNECT" diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index 5e02f52f..ce14525d 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -178,7 +178,9 @@ def with_user_session_token(self, token: str) -> Client: """Create a new Client scoped to the user specified in the user session token. Create a new Client instance from a user session token exchange for an api key scoped to the - user specified in the token. + user specified in the token (the user viewing your app). If running your application locally, + a user session token will not exist, which will cause this method to result in an error needing + to be handled in your application. Parameters ---------- @@ -195,14 +197,18 @@ def with_user_session_token(self, token: str) -> Client: >>> from posit.connect import Client >>> client = Client().with_user_session_token("my-user-session-token") """ - viewer_credentials = self.oauth.get_credentials( + if token is None or token == "": + raise ValueError("token must be set to non-empty string.") + + visitor_credentials = self.oauth.get_credentials( token, requested_token_type=API_KEY_TOKEN_TYPE ) - viewer_api_key = viewer_credentials.get("access_token") - if viewer_api_key is None: - raise ValueError("Unable to retrieve viewer api key.") - return Client(url=self.cfg.url, api_key=viewer_api_key) + visitor_api_key = visitor_credentials.get("access_token", "") + if visitor_api_key == "": + raise ValueError("Unable to retrieve token.") + + return Client(url=self.cfg.url, api_key=visitor_api_key) @property def content(self) -> Content: diff --git a/src/posit/connect/external/databricks.py b/src/posit/connect/external/databricks.py index 6578d659..1f3b0895 100644 --- a/src/posit/connect/external/databricks.py +++ b/src/posit/connect/external/databricks.py @@ -13,9 +13,9 @@ import requests from typing_extensions import Callable, Dict, Optional +from .._utils import is_local from ..client import Client from ..oauth import Credentials -from .external import is_local POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration" POSIT_LOCAL_CLIENT_CREDENTIALS_AUTH_TYPE = "posit-local-client-credentials" diff --git a/src/posit/connect/external/external.py b/src/posit/connect/external/external.py deleted file mode 100644 index b3492ce4..00000000 --- a/src/posit/connect/external/external.py +++ /dev/null @@ -1,11 +0,0 @@ -import os - - -def is_local() -> bool: - """Returns true if called from a piece of content running on a Connect server. - - The connect server will always set the environment variable `RSTUDIO_PRODUCT=CONNECT`. - We can use this environment variable to determine if the content is running locally - or on a Connect server. - """ - return not os.getenv("RSTUDIO_PRODUCT") == "CONNECT" diff --git a/src/posit/connect/external/snowflake.py b/src/posit/connect/external/snowflake.py index 54789c9b..c40c188d 100644 --- a/src/posit/connect/external/snowflake.py +++ b/src/posit/connect/external/snowflake.py @@ -9,8 +9,8 @@ from typing_extensions import Optional +from .._utils import is_local from ..client import Client -from .external import is_local class PositAuthenticator: diff --git a/tests/posit/connect/test_client.py b/tests/posit/connect/test_client.py index c256de89..be6fe9f9 100644 --- a/tests/posit/connect/test_client.py +++ b/tests/posit/connect/test_client.py @@ -85,6 +85,7 @@ def test_init( MockSession.assert_called_once() @responses.activate + @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) def test_with_user_session_token(self): api_key = "12345" url = "https://connect.example.com" @@ -110,13 +111,14 @@ def test_with_user_session_token(self): }, ) - viewer_client = client.with_user_session_token("cit") + visitor_client = client.with_user_session_token("cit") - assert viewer_client.cfg.url == "https://connect.example.com/__api__" - assert viewer_client.cfg.api_key == "api-key" + assert visitor_client.cfg.url == "https://connect.example.com/__api__" + assert visitor_client.cfg.api_key == "api-key" @responses.activate - def test_with_user_session_token_bad_exchange(self): + @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + def test_with_user_session_token_bad_exchange_response_body(self): api_key = "12345" url = "https://connect.example.com" client = Client(api_key=api_key, url=url) @@ -137,8 +139,34 @@ def test_with_user_session_token_bad_exchange(self): json={}, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError) as err: client.with_user_session_token("cit") + assert str(err.value) == "Unable to retrieve token." + + @patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"}) + def test_with_user_session_token_bad_token_deployed(self): + api_key = "12345" + url = "https://connect.example.com" + client = Client(api_key=api_key, url=url) + client._ctx.version = None + + with pytest.raises(ValueError) as err: + client.with_user_session_token("") + assert str(err.value) == "token must be set to non-empty string." + + def test_with_user_session_token_bad_token_local(self): + api_key = "12345" + url = "https://connect.example.com" + client = Client(api_key=api_key, url=url) + client._ctx.version = None + + with pytest.raises(ValueError) as e: + client.with_user_session_token("") + assert str(e.value) == "token must be set to non-empty string." + + with pytest.raises(ValueError) as e: + client.with_user_session_token(None) # type: ignore + assert str(e.value) == "token must be set to non-empty string." def test__del__( self,