Skip to content

Commit 5e175a2

Browse files
committed
Updated mock to use WorkspaceClient instead of ServiceClient
1 parent bb118ca commit 5e175a2

File tree

2 files changed

+189
-25
lines changed

2 files changed

+189
-25
lines changed

azure-quantum/tests/unit/local/mock_client.py

Lines changed: 166 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from azure.core.paging import ItemPaged
1515
from azure.quantum.workspace import Workspace
1616
from types import SimpleNamespace
17-
from azure.quantum._client import ServicesClient
17+
from azure.quantum._client import WorkspaceClient
1818
from azure.quantum._client.models import JobDetails, SessionDetails, ItemDetails
19+
from azure.quantum._client.models import SasUriResponse
1920
from azure.quantum._workspace_connection_params import WorkspaceConnectionParams
2021
from common import (
2122
SUBSCRIPTION_ID,
@@ -156,6 +157,23 @@ def create_or_replace(
156157
self._store.append(job_details)
157158
return job_details
158159

160+
# New WorkspaceClient API: create
161+
def create(
162+
self,
163+
subscription_id: str,
164+
resource_group_name: str,
165+
workspace_name: str,
166+
job_id: str,
167+
resource: JobDetails,
168+
) -> JobDetails:
169+
return self.create_or_replace(
170+
subscription_id,
171+
resource_group_name,
172+
workspace_name,
173+
job_id,
174+
resource,
175+
)
176+
159177
def get(
160178
self,
161179
subscription_id: str,
@@ -168,6 +186,20 @@ def get(
168186
return jd
169187
raise KeyError(job_id)
170188

189+
# Cancel/delete for older API; mark job as cancelled
190+
def delete(
191+
self,
192+
subscription_id: str,
193+
resource_group_name: str,
194+
workspace_name: str,
195+
job_id: str,
196+
) -> None:
197+
for jd in self._store:
198+
if jd.id == job_id:
199+
jd.status = "Cancelled"
200+
return None
201+
raise KeyError(job_id)
202+
171203
def list(
172204
self,
173205
subscription_id: str,
@@ -222,6 +254,23 @@ def create_or_replace(
222254
self._store.append(session_details)
223255
return session_details
224256

257+
# New WorkspaceClient API: open
258+
def open(
259+
self,
260+
subscription_id: str,
261+
resource_group_name: str,
262+
workspace_name: str,
263+
session_id: str,
264+
resource: SessionDetails,
265+
) -> SessionDetails:
266+
return self.create_or_replace(
267+
subscription_id,
268+
resource_group_name,
269+
workspace_name,
270+
session_id,
271+
resource,
272+
)
273+
225274
def close(
226275
self,
227276
subscription_id: str,
@@ -269,6 +318,27 @@ def list(
269318
pass
270319
return _paged(items[skip : skip + top], page_size=top)
271320

321+
# New WorkspaceClient API: listv2 (same behavior as list)
322+
def listv2(
323+
self,
324+
subscription_id: str,
325+
resource_group_name: str,
326+
workspace_name: str,
327+
filter: Optional[str] = None,
328+
orderby: Optional[str] = None,
329+
skip: int = 0,
330+
top: int = 100,
331+
) -> ItemPaged[SessionDetails]:
332+
return self.list(
333+
subscription_id,
334+
resource_group_name,
335+
workspace_name,
336+
filter,
337+
orderby,
338+
skip,
339+
top,
340+
)
341+
272342
def jobs_list(
273343
self,
274344
subscription_id: str,
@@ -356,68 +426,141 @@ def list(
356426
pass
357427
return _paged(items[skip : skip + top], page_size=top)
358428

429+
# New WorkspaceClient API: listv2
430+
def listv2(
431+
self,
432+
subscription_id: str,
433+
resource_group_name: str,
434+
workspace_name: str,
435+
filter: Optional[str] = None,
436+
orderby: Optional[str] = None,
437+
top: int = 100,
438+
skip: int = 0,
439+
) -> ItemPaged[ItemDetails]:
440+
return self.list(
441+
subscription_id,
442+
resource_group_name,
443+
workspace_name,
444+
filter,
445+
orderby,
446+
top,
447+
skip,
448+
)
449+
450+
451+
class ProvidersOperations:
452+
def list(
453+
self,
454+
subscription_id: str,
455+
resource_group_name: str,
456+
workspace_name: str,
457+
) -> ItemPaged:
458+
# Minimal stub: return empty provider list
459+
return _paged([], page_size=100)
460+
461+
462+
class QuotasOperations:
463+
def list(
464+
self,
465+
subscription_id: str,
466+
resource_group_name: str,
467+
workspace_name: str,
468+
) -> ItemPaged:
469+
# Minimal stub: return empty quotas list
470+
return _paged([], page_size=100)
471+
472+
473+
class StorageOperations:
474+
def get_sas_uri(
475+
self,
476+
subscription_id: str,
477+
resource_group_name: str,
478+
workspace_name: str,
479+
*,
480+
blob_details: object,
481+
) -> SasUriResponse:
482+
# Return a dummy SAS URI suitable for tests that might exercise storage
483+
return SasUriResponse({"sasUri": "https://example.com/container?sas-token"})
484+
359485

360486
class MockWorkspaceMgmtClient:
361487
"""Mock management client that avoids network calls to ARM/ARG."""
362-
363-
def __init__(self, credential: Optional[object] = None, base_url: Optional[str] = None, user_agent: Optional[str] = None) -> None:
488+
489+
def __init__(
490+
self,
491+
credential: Optional[object] = None,
492+
base_url: Optional[str] = None,
493+
user_agent: Optional[str] = None,
494+
) -> None:
364495
self._credential = credential
365496
self._base_url = base_url
366497
self._user_agent = user_agent
367-
498+
368499
def close(self) -> None:
369500
"""No-op close for mock."""
370501
pass
371502

372-
def __enter__(self) -> 'MockWorkspaceMgmtClient':
503+
def __enter__(self) -> "MockWorkspaceMgmtClient":
373504
return self
374505

375506
def __exit__(self, *exc_details) -> None:
376507
pass
377508

378-
def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams) -> None:
509+
def load_workspace_from_arg(
510+
self, connection_params: WorkspaceConnectionParams
511+
) -> None:
379512
connection_params.subscription_id = SUBSCRIPTION_ID
380513
connection_params.resource_group = RESOURCE_GROUP
381514
connection_params.location = LOCATION
382515
connection_params.quantum_endpoint = ENDPOINT_URI
383516

384-
def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams) -> None:
517+
def load_workspace_from_arm(
518+
self, connection_params: WorkspaceConnectionParams
519+
) -> None:
385520
connection_params.location = LOCATION
386521
connection_params.quantum_endpoint = ENDPOINT_URI
387522

388523

389-
class MockServicesClient(ServicesClient):
524+
class MockWorkspaceClient:
390525
def __init__(self, authentication_policy: Optional[object] = None) -> None:
391526
# in-memory stores
392527
self._jobs_store: List[JobDetails] = []
393528
self._sessions_store: List[SessionDetails] = []
394-
# operations
395-
self.jobs = JobsOperations(self._jobs_store)
396-
self.sessions = SessionsOperations(self._sessions_store, self._jobs_store)
397-
self.top_level_items = TopLevelItemsOperations(
398-
self._jobs_store, self._sessions_store
529+
# operations grouped under .services to mirror WorkspaceClient
530+
self.services = SimpleNamespace(
531+
jobs=JobsOperations(self._jobs_store),
532+
sessions=SessionsOperations(self._sessions_store, self._jobs_store),
533+
top_level_items=TopLevelItemsOperations(
534+
self._jobs_store, self._sessions_store
535+
),
536+
providers=ProvidersOperations(),
537+
quotas=QuotasOperations(),
538+
storage=StorageOperations(),
399539
)
400-
# Mimic ServicesClient config shape for tests that inspect policy
540+
# Mimic WorkspaceClient config shape for tests that inspect policy
401541
self._config = SimpleNamespace(authentication_policy=authentication_policy)
402542

403-
def __enter__(self) -> 'MockServicesClient':
543+
def __enter__(self) -> "MockWorkspaceClient":
404544
return self
405545

406546
def __exit__(self, *exc_details) -> None:
407547
pass
408548

549+
def close(self) -> None:
550+
pass
551+
409552

410553
class WorkspaceMock(Workspace):
411554
def __init__(self, **kwargs) -> None:
412555
# Create and pass mock management client to prevent network calls
413-
if '_mgmt_client' not in kwargs:
414-
kwargs['_mgmt_client'] = MockWorkspaceMgmtClient()
556+
if "_mgmt_client" not in kwargs:
557+
kwargs["_mgmt_client"] = MockWorkspaceMgmtClient()
415558
super().__init__(**kwargs)
416-
417-
def _create_client(self) -> ServicesClient: # type: ignore[override]
559+
560+
def _create_client(self) -> WorkspaceClient: # type: ignore[override]
418561
# Pass through the Workspace's auth policy to the mock client
419562
auth_policy = self._connection_params.get_auth_policy()
420-
return MockServicesClient(authentication_policy=auth_policy)
563+
return MockWorkspaceClient(authentication_policy=auth_policy)
421564

422565

423566
def seed_jobs(ws: WorkspaceMock) -> None:
@@ -478,7 +621,7 @@ def seed_jobs(ws: WorkspaceMock) -> None:
478621
),
479622
]
480623
for d in samples:
481-
ws._client.jobs.create_or_replace(
624+
ws._client.services.jobs.create_or_replace(
482625
ws.subscription_id, ws.resource_group, ws.name, job_id=d.id, job_details=d
483626
)
484627

@@ -504,7 +647,7 @@ def seed_sessions(ws: WorkspaceMock) -> None:
504647
),
505648
]
506649
for s in samples:
507-
ws._client.sessions.create_or_replace(
650+
ws._client.services.sessions.create_or_replace(
508651
ws.subscription_id,
509652
ws.resource_group,
510653
ws.name,
@@ -515,9 +658,7 @@ def seed_sessions(ws: WorkspaceMock) -> None:
515658

516659
def create_default_workspace() -> WorkspaceMock:
517660
ws = WorkspaceMock(
518-
subscription_id=SUBSCRIPTION_ID,
519-
resource_group=RESOURCE_GROUP,
520-
name=WORKSPACE
661+
subscription_id=SUBSCRIPTION_ID, resource_group=RESOURCE_GROUP, name=WORKSPACE
521662
)
522663
seed_jobs(ws)
523664
seed_sessions(ws)

azure-quantum/tests/unit/local/test_workspace.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,26 @@ def test_workspace_context_manager_calls_enter_exit():
465465
# Verify __exit__ was called on both clients after exiting context
466466
ws._client.__exit__.assert_called_once()
467467
ws._mgmt_client.__exit__.assert_called_once()
468+
469+
470+
def test_get_container_uri_uses_linked_storage_sas_when_storage_none():
471+
"""When storage is None, get_container_uri should use linked storage via service SAS."""
472+
ws = WorkspaceMock(
473+
subscription_id=SUBSCRIPTION_ID,
474+
resource_group=RESOURCE_GROUP,
475+
name=WORKSPACE,
476+
)
477+
assert ws.storage is None
478+
479+
with mock.patch(
480+
"azure.quantum.storage.ContainerClient.from_container_url",
481+
return_value=mock.MagicMock(),
482+
):
483+
with mock.patch(
484+
"azure.quantum.storage.create_container_using_client",
485+
return_value=None,
486+
):
487+
uri = ws.get_container_uri(job_id="job-123")
488+
assert isinstance(uri, str)
489+
assert "https://example.com/" in uri
490+
assert "sas-token" in uri

0 commit comments

Comments
 (0)