diff --git a/qnexus/client/__init__.py b/qnexus/client/__init__.py index 7d85a5b..180960d 100644 --- a/qnexus/client/__init__.py +++ b/qnexus/client/__init__.py @@ -26,78 +26,67 @@ VERSION = version("qnexus") -class AuthHandler(httpx.Auth): - """Custom nexus auth handler""" +def get_cookies_from_disk() -> httpx.Cookies: + cookies = httpx.Cookies() + try: + refresh_token = read_token("refresh_token") + cookies.set("myqos_oat", refresh_token, domain=CONFIG.domain) + except FileNotFoundError: + pass + try: + access_token = read_token("access_token") + cookies.set("myqos_id", access_token, domain=CONFIG.domain) + except FileNotFoundError: + pass - cookies: httpx.Cookies + return cookies - def __init__(self) -> None: - self.cookies = httpx.Cookies() - self.reload_tokens() - super().__init__() +def set_cookie_header(cookies: httpx.Cookies, request: httpx.Request) -> None: + """by default cookies.set_cookie_header(...) doesn't overwrite cookies if they already exist in the request header""" + if request.headers.get("cookie"): + request.headers.pop("cookie") + cookies.set_cookie_header(request) - def reload_tokens(self) -> None: - """Clear tokens and attempt to reload from the file system.""" - try: - self.cookies.clear() - token = read_token("refresh_token") - self.cookies.set("myqos_oat", token, domain=CONFIG.domain) - id_token = read_token("access_token") - self.cookies.set("myqos_id", id_token, domain=CONFIG.domain) - except FileNotFoundError: - pass + +class AuthHandler(httpx.Auth): + """Custom nexus auth handler""" def auth_flow( self, request: httpx.Request ) -> typing.Generator[httpx.Request, httpx.Response, None]: - self.cookies.set_cookie_header(request) - + cookies = get_cookies_from_disk() + set_cookie_header(cookies, request) response = yield request _check_sunset_header(request, response) if response.status_code == 401: - if self.cookies.get("myqos_oat") is None: - try: - token = read_token( - "refresh_token", - ) - self.cookies.set("myqos_oat", token, domain=CONFIG.domain) - except FileNotFoundError as exc: - raise AuthenticationError( - "Not authenticated. Please run `qnx login` in your terminal." - ) from exc - - auth_response = yield self.build_refresh_request() + auth_response = yield httpx.Request( + method="POST", + url=f"{CONFIG.url}/auth/tokens/refresh", + cookies=cookies, + headers={VERSION_HEADER: VERSION}, + ) + if auth_response.status_code == 401: raise AuthenticationError( "Not authenticated. Please run `qnx login` in your terminal." ) auth_response.raise_for_status() - self.cookies.extract_cookies(auth_response) + auth_response_cookies = httpx.Cookies() + auth_response_cookies.extract_cookies(auth_response) write_token( "access_token", - self.cookies.get("myqos_id", domain=CONFIG.domain) or "", + auth_response_cookies.get("myqos_id") or "", ) - if request.headers.get("cookie"): - request.headers.pop("cookie") - self.cookies.set_cookie_header(request) _check_version_headers(auth_response) - yield request + set_cookie_header(auth_response_cookies, request) - def build_refresh_request(self) -> httpx.Request: - """Build the request for refreshing the id token.""" - self.cookies.delete("myqos_id") # We need to delete any existing id token first - return httpx.Request( - method="POST", - url=f"{CONFIG.url}/auth/tokens/refresh", - cookies=self.cookies, - headers={VERSION_HEADER: VERSION}, - ) + yield request _nexus_client: httpx.Client | None = None @@ -113,8 +102,6 @@ def get_nexus_client(reload: bool = False) -> httpx.Client: global _nexus_client if _nexus_client is None or reload: _auth_handler = AuthHandler() - _auth_handler.reload_tokens() - _nexus_client = httpx.Client( base_url=CONFIG.url, auth=_auth_handler, diff --git a/scripts/run_unit_tests.sh b/scripts/run_unit_tests.sh index d3675ee..1825adb 100755 --- a/scripts/run_unit_tests.sh +++ b/scripts/run_unit_tests.sh @@ -5,7 +5,6 @@ set -e # Order doesn't matter but auth tests manipulate environment variables # and should be run separately uv run pytest --cov-reset tests/test_auth.py::test_token_refresh -uv run pytest tests/test_auth.py::test_nexus_client_reloads_tokens uv run pytest tests/test_auth.py::test_nexus_client_reloads_domain uv run pytest tests/test_auth.py::test_token_refresh_expired @@ -13,4 +12,4 @@ uv run pytest tests/test_auth.py::test_token_refresh_expired echo "Running non-auth tests" uv run pytest tests/ -v --ignore=tests/test_auth.py -echo -e "\nšŸŽ‰ All tests passed successfully!" \ No newline at end of file +echo -e "\nšŸŽ‰ All tests passed successfully!" diff --git a/tests/test_auth.py b/tests/test_auth.py index 7ea1e74..e8e7b10 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -85,7 +85,6 @@ def test_token_refresh() -> None: # Confirm that the access token was updated assert read_token("access_token") == refreshed_access_token - assert get_nexus_client().auth.cookies.get("myqos_id") == refreshed_access_token # type: ignore # confirm that the request headers were updated first_cookie_header = list_project_route.calls[0].request.headers["cookie"] @@ -119,27 +118,6 @@ def test_token_refresh_expired() -> None: assert refresh_token_route.called -def test_nexus_client_reloads_tokens() -> None: - """Test the reload functionality of the nexus client. - - Test that if we write new tokens and reload the client, - that the new tokens are used.""" - - oat_one = "dummy_oat_one" - oat_two = "dummy_oat_two" - - write_token("refresh_token", oat_one) - client_one = get_nexus_client(reload=True) - assert client_one.auth.cookies.get("myqos_oat") == oat_one # type: ignore - - write_token("refresh_token", oat_two) - client_two = get_nexus_client() - assert client_two.auth.cookies.get("myqos_oat") == oat_one # type: ignore - - client_two = get_nexus_client(reload=True) - assert client_two.auth.cookies.get("myqos_oat") == oat_two # type: ignore - - def test_nexus_client_reloads_domain() -> None: """Test the reload functionality of the nexus client. We should be able to change the domain in the config