Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 35 additions & 48 deletions qnexus/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,78 +26,67 @@
VERSION = version("qnexus")


class AuthHandler(httpx.Auth):
"""Custom nexus auth handler"""
def get_cookies_from_disk() -> httpx.Cookies:
cookies = httpx.Cookies()
try:
refresh_token = read_token("refresh_token")
cookies.set("myqos_oat", refresh_token, domain=CONFIG.domain)
except FileNotFoundError:
pass
try:
access_token = read_token("access_token")
cookies.set("myqos_id", access_token, domain=CONFIG.domain)
except FileNotFoundError:
pass

cookies: httpx.Cookies
return cookies

def __init__(self) -> None:
self.cookies = httpx.Cookies()
self.reload_tokens()

super().__init__()
def set_cookie_header(cookies: httpx.Cookies, request: httpx.Request) -> None:
"""by default cookies.set_cookie_header(...) doesn't overwrite cookies if they already exist in the request header"""
if request.headers.get("cookie"):
request.headers.pop("cookie")
cookies.set_cookie_header(request)

def reload_tokens(self) -> None:
"""Clear tokens and attempt to reload from the file system."""
try:
self.cookies.clear()
token = read_token("refresh_token")
self.cookies.set("myqos_oat", token, domain=CONFIG.domain)
id_token = read_token("access_token")
self.cookies.set("myqos_id", id_token, domain=CONFIG.domain)
except FileNotFoundError:
pass

class AuthHandler(httpx.Auth):
"""Custom nexus auth handler"""

def auth_flow(
self, request: httpx.Request
) -> typing.Generator[httpx.Request, httpx.Response, None]:
self.cookies.set_cookie_header(request)

cookies = get_cookies_from_disk()
set_cookie_header(cookies, request)
response = yield request

_check_sunset_header(request, response)

if response.status_code == 401:
if self.cookies.get("myqos_oat") is None:
try:
token = read_token(
"refresh_token",
)
self.cookies.set("myqos_oat", token, domain=CONFIG.domain)
except FileNotFoundError as exc:
raise AuthenticationError(
"Not authenticated. Please run `qnx login` in your terminal."
) from exc

auth_response = yield self.build_refresh_request()
auth_response = yield httpx.Request(
method="POST",
url=f"{CONFIG.url}/auth/tokens/refresh",
cookies=cookies,
headers={VERSION_HEADER: VERSION},
)

if auth_response.status_code == 401:
raise AuthenticationError(
"Not authenticated. Please run `qnx login` in your terminal."
)

auth_response.raise_for_status()
self.cookies.extract_cookies(auth_response)
auth_response_cookies = httpx.Cookies()
auth_response_cookies.extract_cookies(auth_response)

write_token(
"access_token",
self.cookies.get("myqos_id", domain=CONFIG.domain) or "",
auth_response_cookies.get("myqos_id") or "",
)
if request.headers.get("cookie"):
request.headers.pop("cookie")
self.cookies.set_cookie_header(request)

_check_version_headers(auth_response)
yield request
set_cookie_header(auth_response_cookies, request)

def build_refresh_request(self) -> httpx.Request:
"""Build the request for refreshing the id token."""
self.cookies.delete("myqos_id") # We need to delete any existing id token first
return httpx.Request(
method="POST",
url=f"{CONFIG.url}/auth/tokens/refresh",
cookies=self.cookies,
headers={VERSION_HEADER: VERSION},
)
yield request


_nexus_client: httpx.Client | None = None
Expand All @@ -113,8 +102,6 @@ def get_nexus_client(reload: bool = False) -> httpx.Client:
global _nexus_client
if _nexus_client is None or reload:
_auth_handler = AuthHandler()
_auth_handler.reload_tokens()

_nexus_client = httpx.Client(
base_url=CONFIG.url,
auth=_auth_handler,
Expand Down
3 changes: 1 addition & 2 deletions scripts/run_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ set -e
# Order doesn't matter but auth tests manipulate environment variables
# and should be run separately
uv run pytest --cov-reset tests/test_auth.py::test_token_refresh
uv run pytest tests/test_auth.py::test_nexus_client_reloads_tokens
uv run pytest tests/test_auth.py::test_nexus_client_reloads_domain
uv run pytest tests/test_auth.py::test_token_refresh_expired


echo "Running non-auth tests"
uv run pytest tests/ -v --ignore=tests/test_auth.py

echo -e "\n🎉 All tests passed successfully!"
echo -e "\n🎉 All tests passed successfully!"
22 changes: 0 additions & 22 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_token_refresh() -> None:

# Confirm that the access token was updated
assert read_token("access_token") == refreshed_access_token
assert get_nexus_client().auth.cookies.get("myqos_id") == refreshed_access_token # type: ignore

# confirm that the request headers were updated
first_cookie_header = list_project_route.calls[0].request.headers["cookie"]
Expand Down Expand Up @@ -119,27 +118,6 @@ def test_token_refresh_expired() -> None:
assert refresh_token_route.called


def test_nexus_client_reloads_tokens() -> None:
"""Test the reload functionality of the nexus client.

Test that if we write new tokens and reload the client,
that the new tokens are used."""

oat_one = "dummy_oat_one"
oat_two = "dummy_oat_two"

write_token("refresh_token", oat_one)
client_one = get_nexus_client(reload=True)
assert client_one.auth.cookies.get("myqos_oat") == oat_one # type: ignore

write_token("refresh_token", oat_two)
client_two = get_nexus_client()
assert client_two.auth.cookies.get("myqos_oat") == oat_one # type: ignore

client_two = get_nexus_client(reload=True)
assert client_two.auth.cookies.get("myqos_oat") == oat_two # type: ignore


def test_nexus_client_reloads_domain() -> None:
"""Test the reload functionality of the nexus client.
We should be able to change the domain in the config
Expand Down