Skip to content

Commit 2c2f42f

Browse files
committed
[QI2-1211] Move Qiskit Plugin's API subpackage to compute-api-client
1 parent 9fe162c commit 2c2f42f

File tree

13 files changed

+879
-4
lines changed

13 files changed

+879
-4
lines changed

poetry.lock

Lines changed: 305 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "qi-compute-api-client"
3-
version = "0.42.0"
3+
version = "0.43.0"
44
description = "An API client for the Compute Job Manager of Quantum Inspire."
55
license = "Apache-2.0"
66
authors = ["Quantum Inspire <[email protected]>"]
@@ -18,15 +18,18 @@ classifiers = [
1818
]
1919
packages = [
2020
{ include = "compute_api_client" },
21+
{ include = "qi2_shared" },
2122
]
22-
exclude = ["compute_api_client/test"]
23+
exclude = ["compute_api_client/test", "tests"]
2324

2425
[tool.poetry.dependencies]
2526
# Should not exceed package python version, so be conservative
2627
aiohttp = "^3.10.5"
28+
pydantic = "^2.10.4"
2729
python = "^3.8"
2830
python-dateutil = "^2.8.2"
2931
urllib3 = "^2.0.0"
32+
requests = "^2.32.3"
3033

3134
[tool.poetry.group.dev.dependencies]
3235
pytest = {extras = ["toml"], version = "^8.0.0"}

qi2_shared/__init__.py

Whitespace-only changes.

qi2_shared/authentication.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import time
2+
from typing import Any, Tuple, cast
3+
4+
import requests
5+
6+
from qi2_shared.settings import ApiSettings, TokenInfo, Url
7+
8+
9+
class AuthorisationError(Exception):
10+
"""Indicates that the authorisation permanently went wrong."""
11+
pass
12+
13+
14+
class IdentityProvider:
15+
"""Class for interfacing with the IdentityProvider."""
16+
17+
def __init__(self, well_known_endpoint: str):
18+
self._well_known_endpoint = well_known_endpoint
19+
self._token_endpoint, self._device_endpoint = self._get_endpoints()
20+
self._headers = {"Content-Type": "application/x-www-form-urlencoded"}
21+
22+
def _get_endpoints(self) -> Tuple[str, str]:
23+
response = requests.get(self._well_known_endpoint)
24+
response.raise_for_status()
25+
config = response.json()
26+
return config["token_endpoint"], config["device_authorization_endpoint"]
27+
28+
def refresh_access_token(self, client_id: str, refresh_token: str) -> dict[str, Any]:
29+
data = {
30+
"grant_type": "refresh_token",
31+
"client_id": client_id,
32+
"refresh_token": refresh_token,
33+
}
34+
response = requests.post(self._token_endpoint, headers=self._headers, data=data)
35+
response.raise_for_status()
36+
return cast(dict[str, Any], response.json())
37+
38+
39+
class OauthDeviceSession:
40+
"""Class for storing OAuth session information and refreshing tokens when needed."""
41+
42+
def __init__(self, host: Url, settings: ApiSettings, identity_provider: IdentityProvider):
43+
self._api_settings = settings
44+
_auth_settings = settings.auths[host]
45+
self._host = host
46+
self._client_id = _auth_settings.client_id
47+
self._token_info = _auth_settings.tokens
48+
self._refresh_time_reduction = 5 # the number of seconds to refresh the expiration time
49+
self._identity_provider = identity_provider
50+
51+
def refresh(self) -> TokenInfo:
52+
if self._token_info is None:
53+
raise AuthorisationError("You should authenticate first before you can refresh")
54+
55+
if self._token_info.access_expires_at > time.time() + self._refresh_time_reduction:
56+
return self._token_info
57+
58+
try:
59+
self._token_info = TokenInfo(
60+
**self._identity_provider.refresh_access_token(self._client_id, self._token_info.refresh_token)
61+
)
62+
self._api_settings.store_tokens(self._host, self._token_info)
63+
return self._token_info
64+
except requests.HTTPError as e:
65+
raise AuthorisationError(f"An error occurred during token refresh: {e}")

qi2_shared/client.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Any, Optional
2+
3+
import compute_api_client
4+
5+
from qi2_shared.authentication import IdentityProvider, OauthDeviceSession
6+
from qi2_shared.settings import ApiSettings
7+
8+
9+
class Configuration(compute_api_client.Configuration): # type: ignore[misc]
10+
"""Original Configuration class in compute_api_client does not handle refreshing bearer tokens, so we need to add
11+
some functionality."""
12+
13+
def __init__(self, host: str, oauth_session: OauthDeviceSession, **kwargs: Any):
14+
self._oauth_session = oauth_session
15+
super().__init__(host=host, **kwargs)
16+
17+
def auth_settings(self) -> Any:
18+
token_info = self._oauth_session.refresh()
19+
self.access_token = token_info.access_token
20+
return super().auth_settings()
21+
22+
23+
_config: Optional[Configuration] = None
24+
25+
26+
def connect() -> None:
27+
"""Set connection configuration for the Quantum Inspire API.
28+
29+
Call after logging in with the CLI. Will remove old configuration.
30+
"""
31+
global _config
32+
settings = ApiSettings.from_config_file()
33+
34+
tokens = settings.auths[settings.default_host].tokens
35+
36+
if tokens is None:
37+
raise ValueError("No access token found for the default host. Please connect to Quantum Inspire using the CLI.")
38+
39+
host = settings.default_host
40+
_config = Configuration(
41+
host=host,
42+
oauth_session=OauthDeviceSession(host, settings, IdentityProvider(settings.auths[host].well_known_endpoint)),
43+
)
44+
45+
46+
def config() -> Configuration:
47+
global _config
48+
if _config is None:
49+
connect()
50+
51+
assert _config is not None
52+
return _config

qi2_shared/pagination.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any, Awaitable, Callable, Generic, List, Optional, TypeVar, Union, cast
2+
3+
from pydantic import BaseModel, Field
4+
from typing_extensions import Annotated
5+
6+
PageType = TypeVar("PageType")
7+
ItemType = TypeVar("ItemType")
8+
9+
10+
class PageInterface(BaseModel, Generic[ItemType]):
11+
"""The page models in the generated API client don't inherit from a common base class, so we have to trick the
12+
typing system a bit with this fake base class."""
13+
14+
items: List[ItemType]
15+
total: Optional[Annotated[int, Field(strict=True, ge=0)]]
16+
page: Optional[Annotated[int, Field(strict=True, ge=1)]]
17+
size: Optional[Annotated[int, Field(strict=True, ge=1)]]
18+
pages: Optional[Annotated[int, Field(strict=True, ge=0)]] = None
19+
20+
21+
class PageReader(Generic[PageType, ItemType]):
22+
"""Helper class for reading fastapi-pagination style pages returned by the compute_api_client."""
23+
24+
async def get_all(self, api_call: Callable[..., Awaitable[PageType]], **kwargs: Any) -> List[ItemType]:
25+
"""Get all items from an API call that supports paging."""
26+
items: List[ItemType] = []
27+
page = 1
28+
29+
while True:
30+
response = cast(PageInterface[ItemType], await api_call(page=page, **kwargs))
31+
32+
items.extend(response.items)
33+
page += 1
34+
if response.pages is None or page > response.pages:
35+
break
36+
return items
37+
38+
async def get_single(self, api_call: Callable[..., Awaitable[PageType]], **kwargs: Any) -> Union[ItemType, None]:
39+
"""Get a single item from an API call that supports paging."""
40+
response = cast(PageInterface[ItemType], await api_call(**kwargs))
41+
if len(response.items) > 1:
42+
raise RuntimeError(f"Response contains more than one item -> {kwargs}.")
43+
44+
return response.items[0] if response.items else None

qi2_shared/settings.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Module containing the handler for the Quantum Inspire persistent configuration."""
2+
3+
from __future__ import annotations
4+
5+
import time
6+
from pathlib import Path
7+
from typing import Dict, Optional
8+
9+
from pydantic import BaseModel, BeforeValidator, Field, HttpUrl
10+
from typing_extensions import Annotated
11+
12+
Url = Annotated[str, BeforeValidator(lambda value: str(HttpUrl(value)).rstrip("/"))]
13+
API_SETTINGS_FILE = Path.joinpath(Path.home(), ".quantuminspire", "config.json")
14+
15+
16+
class TokenInfo(BaseModel):
17+
"""A pydantic model for storing all information regarding oauth access and refresh tokens."""
18+
19+
access_token: str
20+
expires_in: int # [s]
21+
refresh_token: str
22+
refresh_expires_in: Optional[int] = None # [s]
23+
generated_at: float = Field(default_factory=time.time)
24+
25+
@property
26+
def access_expires_at(self) -> float:
27+
"""Unix timestamp containing the time when the access token will expire."""
28+
return self.generated_at + self.expires_in
29+
30+
31+
class AuthSettings(BaseModel):
32+
"""Pydantic model for storing all auth related settings for a given host."""
33+
34+
client_id: str
35+
code_challenge_method: str
36+
code_verifyer_length: int
37+
well_known_endpoint: Url
38+
tokens: Optional[TokenInfo]
39+
team_member_id: Optional[int]
40+
41+
42+
class ApiSettings(BaseModel):
43+
"""The settings class for the Quantum Inspire persistent configuration."""
44+
45+
auths: Dict[Url, AuthSettings]
46+
default_host: Url
47+
48+
def store_tokens(self, host: Url, tokens: TokenInfo, path: Path = API_SETTINGS_FILE) -> None:
49+
"""Stores the team_member_id, access and refresh tokens in the config.json file.
50+
51+
Args:
52+
host: The hostname of the API for which the tokens are intended.
53+
tokens: OAuth access and refresh tokens.
54+
path: The path to the config.json file. Defaults to API_SETTINGS_FILE.
55+
Returns:
56+
None
57+
"""
58+
self.auths[host].tokens = tokens
59+
path.write_text(self.model_dump_json(indent=2))
60+
61+
@classmethod
62+
def from_config_file(cls, path: Path = API_SETTINGS_FILE) -> ApiSettings:
63+
"""Load the configuration from a file."""
64+
if not path.is_file():
65+
raise FileNotFoundError("No configuration file found. Please connect to Quantum Inspire using the CLI.")
66+
67+
api_settings = path.read_text()
68+
return ApiSettings.model_validate_json(api_settings)

tests/__init__.py

Whitespace-only changes.

tests/shared/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
3+
from qi2_shared.settings import AuthSettings, TokenInfo
4+
5+
6+
@pytest.fixture
7+
def token_info() -> TokenInfo:
8+
return TokenInfo(
9+
access_token="access_token",
10+
expires_in=100,
11+
refresh_token="refresh_token",
12+
refresh_expires_in=1000,
13+
generated_at=1,
14+
)
15+
16+
17+
@pytest.fixture
18+
def auth_settings(token_info: TokenInfo) -> AuthSettings:
19+
return AuthSettings(
20+
client_id="client_id",
21+
code_challenge_method="code_challenge_method",
22+
code_verifyer_length=1,
23+
well_known_endpoint="https://host.com/well-known-endpoint",
24+
tokens=token_info,
25+
team_member_id=1,
26+
)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import time
2+
from typing import Any
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
import responses
7+
from responses import matchers
8+
9+
from qi2_shared.authentication import AuthorisationError, IdentityProvider, OauthDeviceSession
10+
from qi2_shared.settings import ApiSettings, AuthSettings, TokenInfo
11+
12+
13+
@pytest.fixture
14+
def identity_provider_mock() -> MagicMock:
15+
return MagicMock(spec=IdentityProvider)
16+
17+
18+
@pytest.fixture
19+
def api_settings_mock(auth_settings: AuthSettings) -> MagicMock:
20+
api_settings = MagicMock(spec=ApiSettings)
21+
api_settings.default_host = "https://host.com"
22+
api_settings.auths = {api_settings.default_host: auth_settings}
23+
return api_settings
24+
25+
26+
def test_oauth_device_session_refresh_no_token(api_settings_mock: MagicMock, identity_provider_mock: MagicMock) -> None:
27+
# Arrange
28+
api_settings_mock.auths[api_settings_mock.default_host].tokens = None
29+
session = OauthDeviceSession("https://host.com", api_settings_mock, identity_provider_mock)
30+
31+
# Act & Assert
32+
with pytest.raises(AuthorisationError):
33+
session.refresh()
34+
35+
36+
def test_oauth_device_session_refresh_token_not_expired(
37+
api_settings_mock: MagicMock, identity_provider_mock: MagicMock
38+
) -> None:
39+
# Arrange
40+
auth_settings = api_settings_mock.auths[api_settings_mock.default_host]
41+
auth_settings.tokens.generated_at = time.time()
42+
session = OauthDeviceSession("https://host.com", api_settings_mock, identity_provider_mock)
43+
44+
# Act
45+
token_info = session.refresh()
46+
47+
# Assert
48+
assert token_info == auth_settings.tokens
49+
50+
identity_provider_mock.refresh_access_token.assert_not_called()
51+
52+
53+
def test_oauth_device_session_refresh_token_expired(
54+
api_settings_mock: MagicMock, identity_provider_mock: MagicMock
55+
) -> None:
56+
# Arrange
57+
session = OauthDeviceSession("https://host.com", api_settings_mock, identity_provider_mock)
58+
new_token_info: dict[str, Any] = {
59+
"access_token": "new_access_token",
60+
"expires_in": 100,
61+
"refresh_token": "new_refresh_token",
62+
"refresh_expires_in": 1000,
63+
"generated_at": time.time(),
64+
}
65+
66+
identity_provider_mock.refresh_access_token.return_value = new_token_info
67+
68+
# Act
69+
token_info = session.refresh()
70+
71+
# Assert
72+
assert token_info == TokenInfo(**new_token_info)
73+
74+
identity_provider_mock.refresh_access_token.assert_called_once_with("client_id", "refresh_token")
75+
api_settings_mock.store_tokens.assert_called_once_with("https://host.com", token_info)
76+
77+
78+
@responses.activate
79+
def test_identity_provider_refresh_access_token() -> None:
80+
# Arrange
81+
token_info = {"token": "something", "some": "other_data"}
82+
client_id = "some_client"
83+
old_refresh_token = "old_token"
84+
85+
responses.get(
86+
"https://host.com/well-known-endpoint",
87+
json={
88+
"token_endpoint": "https://host.com/token-endpoint",
89+
"device_authorization_endpoint": "https://host.com/device-endpoint",
90+
},
91+
)
92+
responses.post(
93+
"https://host.com/token-endpoint",
94+
json=token_info,
95+
match=[
96+
matchers.urlencoded_params_matcher(
97+
{"grant_type": "refresh_token", "client_id": client_id, "refresh_token": old_refresh_token}
98+
)
99+
],
100+
)
101+
102+
# Act
103+
provider = IdentityProvider("https://host.com/well-known-endpoint")
104+
105+
token = provider.refresh_access_token(client_id, old_refresh_token)
106+
107+
# Assert
108+
assert token == token_info

0 commit comments

Comments
 (0)