Skip to content

Commit 8c9f74e

Browse files
authored
Introduce workspace kind and filter targets only for v1 (#712)
1 parent b91f506 commit 8c9f74e

File tree

5 files changed

+67
-16
lines changed

5 files changed

+67
-16
lines changed

azure-quantum/azure/quantum/_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ class EnvironmentKind(Enum):
5151
DOGFOOD = 3
5252

5353

54+
class WorkspaceKind(Enum):
55+
V1 = "V1"
56+
V2 = "V2"
57+
58+
5459
class ConnectionConstants:
5560
DATA_PLANE_CREDENTIAL_SCOPE = "https://quantum.microsoft.com/.default"
5661
ARM_CREDENTIAL_SCOPE = "https://management.azure.com/.default"

azure-quantum/azure/quantum/_mgmt_client.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import logging
1111
from http import HTTPStatus
12-
from typing import Any, Optional, cast
12+
from typing import Any, Dict, Optional, cast
1313
from azure.core import PipelineClient
1414
from azure.core.credentials import TokenProvider
1515
from azure.core.pipeline import policies
@@ -104,8 +104,8 @@ def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams)
104104
query += f"\n | where location =~ '{connection_params.location}'"
105105

106106
query += """
107-
| extend endpointUri = tostring(properties.endpointUri)
108-
| project name, subscriptionId, resourceGroup, location, endpointUri
107+
| extend endpointUri = tostring(properties.endpointUri), workspaceKind = tostring(properties.workspaceKind)
108+
| project name, subscriptionId, resourceGroup, location, endpointUri, workspaceKind
109109
"""
110110

111111
request_body = {
@@ -143,20 +143,22 @@ def load_workspace_from_arg(self, connection_params: WorkspaceConnectionParams)
143143
f"Please specify additional connection parameters. {self.CONNECT_DOC_MESSAGE}"
144144
)
145145

146-
workspace_data = data[0]
146+
workspace_data: Dict[str, Any] = data[0]
147147

148148
connection_params.subscription_id = workspace_data.get('subscriptionId')
149149
connection_params.resource_group = workspace_data.get('resourceGroup')
150150
connection_params.location = workspace_data.get('location')
151151
connection_params.quantum_endpoint = workspace_data.get('endpointUri')
152+
connection_params.workspace_kind = workspace_data.get('workspaceKind')
152153

153154
logger.debug(
154-
"Found workspace '%s' in subscription '%s', resource group '%s', location '%s', endpoint '%s'",
155+
"Found workspace '%s' in subscription '%s', resource group '%s', location '%s', endpoint '%s', kind '%s'.",
155156
connection_params.workspace_name,
156157
connection_params.subscription_id,
157158
connection_params.resource_group,
158159
connection_params.location,
159-
connection_params.quantum_endpoint
160+
connection_params.quantum_endpoint,
161+
connection_params.workspace_kind
160162
)
161163

162164
# If one of the required parameters is missing, probably workspace in failed provisioning state
@@ -194,7 +196,7 @@ def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams)
194196
try:
195197
response = self._client.send_request(request)
196198
response.raise_for_status()
197-
workspace_data = response.json()
199+
workspace_data: Dict[str, Any] = response.json()
198200
except HttpResponseError as e:
199201
if e.status_code == HTTPStatus.NOT_FOUND:
200202
raise ValueError(
@@ -225,7 +227,7 @@ def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams)
225227
)
226228

227229
# Extract and apply endpoint URI from properties
228-
properties = workspace_data.get("properties", {})
230+
properties: Dict[str, Any] = workspace_data.get("properties", {})
229231
endpoint_uri = properties.get("endpointUri")
230232
if endpoint_uri:
231233
connection_params.quantum_endpoint = endpoint_uri
@@ -237,3 +239,11 @@ def load_workspace_from_arm(self, connection_params: WorkspaceConnectionParams)
237239
f"Failed to retrieve endpoint uri for workspace '{connection_params.workspace_name}'. "
238240
f"Please check that workspace is in valid state."
239241
)
242+
243+
# Set workspaceKind if available
244+
workspace_kind = properties.get("workspaceKind")
245+
if workspace_kind:
246+
connection_params.workspace_kind = workspace_kind
247+
logger.debug(
248+
"Updated workspace kind from ARM: %s", connection_params.workspace_kind
249+
)

azure-quantum/azure/quantum/_workspace_connection_params.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from azure.identity import DefaultAzureCredential
1818
from azure.quantum._constants import (
1919
EnvironmentKind,
20+
WorkspaceKind,
2021
EnvironmentVariables,
2122
ConnectionConstants,
2223
GUID_REGEX_PATTERN,
@@ -48,7 +49,7 @@ class WorkspaceConnectionParams:
4849
ResourceGroupName=(?P<resource_group>[^\s;]+);
4950
WorkspaceName=(?P<workspace_name>[^\s;]+);
5051
ApiKey=(?P<api_key>[^\s;]+);
51-
QuantumEndpoint=(?P<quantum_endpoint>https://(?P<location>[a-zA-Z0-9]+)(?:-v2)?.quantum(?:-test)?.azure.com/);
52+
QuantumEndpoint=(?P<quantum_endpoint>https://(?P<location>[a-zA-Z0-9]+)(?:-(?P<workspace_kind>v2))?.quantum(?:-test)?.azure.com/);
5253
""",
5354
re.VERBOSE | re.IGNORECASE)
5455

@@ -80,13 +81,15 @@ def __init__(
8081
api_version: Optional[str] = None,
8182
connection_string: Optional[str] = None,
8283
on_new_client_request: Optional[Callable] = None,
84+
workspace_kind: Optional[str] = None,
8385
):
8486
# fields are used for these properties since
8587
# they have special getters/setters
8688
self._location = None
8789
self._environment = None
8890
self._quantum_endpoint = None
8991
self._arm_endpoint = None
92+
self._workspace_kind = None
9093
# regular connection properties
9194
self.subscription_id = None
9295
self.resource_group = None
@@ -120,6 +123,7 @@ def __init__(
120123
user_agent=user_agent,
121124
user_agent_app_id=user_agent_app_id,
122125
workspace_name=workspace_name,
126+
workspace_kind=workspace_kind,
123127
)
124128
self.apply_resource_id(resource_id=resource_id)
125129
# Validate connection parameters if they are set
@@ -272,6 +276,19 @@ def api_key(self, value: str):
272276
self.credential = AzureKeyCredential(value)
273277
self._api_key = value
274278

279+
@property
280+
def workspace_kind(self) -> WorkspaceKind:
281+
"""
282+
The workspace kind, such as V1 or V2.
283+
Defaults to WorkspaceKind.V1
284+
"""
285+
return self._workspace_kind or WorkspaceKind.V1
286+
287+
@workspace_kind.setter
288+
def workspace_kind(self, value: str):
289+
if isinstance(value, str):
290+
self._workspace_kind = WorkspaceKind[value.upper()]
291+
275292
def __repr__(self):
276293
"""
277294
Print all fields and properties.
@@ -331,6 +348,7 @@ def merge(
331348
client_id: Optional[str] = None,
332349
api_version: Optional[str] = None,
333350
api_key: Optional[str] = None,
351+
workspace_kind: Optional[str] = None,
334352
):
335353
"""
336354
Set all fields/properties with `not None` values
@@ -352,6 +370,7 @@ def merge(
352370
user_agent_app_id=user_agent_app_id,
353371
workspace_name=workspace_name,
354372
api_key=api_key,
373+
workspace_kind=workspace_kind,
355374
merge_default_mode=False,
356375
)
357376
return self
@@ -372,6 +391,7 @@ def apply_defaults(
372391
client_id: Optional[str] = None,
373392
api_version: Optional[str] = None,
374393
api_key: Optional[str] = None,
394+
workspace_kind: Optional[str] = None,
375395
) -> WorkspaceConnectionParams:
376396
"""
377397
Set all fields/properties with `not None` values
@@ -394,6 +414,7 @@ def apply_defaults(
394414
user_agent_app_id=user_agent_app_id,
395415
workspace_name=workspace_name,
396416
api_key=api_key,
417+
workspace_kind=workspace_kind,
397418
merge_default_mode=True,
398419
)
399420
return self
@@ -415,6 +436,7 @@ def _merge(
415436
client_id: Optional[str] = None,
416437
api_version: Optional[str] = None,
417438
api_key: Optional[str] = None,
439+
workspace_kind: Optional[str] = None,
418440
):
419441
"""
420442
Set all fields/properties with `not None` values
@@ -447,6 +469,7 @@ def _get_value_or_default(old_value, new_value):
447469
# the private field as the old_value
448470
self.quantum_endpoint = _get_value_or_default(self._quantum_endpoint, quantum_endpoint)
449471
self.arm_endpoint = _get_value_or_default(self._arm_endpoint, arm_endpoint)
472+
self.workspace_kind = _get_value_or_default(self._workspace_kind, workspace_kind)
450473
return self
451474

452475
def _merge_connection_params(
@@ -476,6 +499,7 @@ def _merge_connection_params(
476499
# pylint: disable=protected-access
477500
arm_endpoint=connection_params._arm_endpoint,
478501
quantum_endpoint=connection_params._quantum_endpoint,
502+
workspace_kind=connection_params._workspace_kind,
479503
)
480504
return self
481505

@@ -640,4 +664,5 @@ def get_value(group_name):
640664
quantum_endpoint=get_value('quantum_endpoint'),
641665
api_key=get_value('api_key'),
642666
arm_endpoint=get_value('arm_endpoint'),
667+
workspace_kind=get_value('workspace_kind'),
643668
)

azure-quantum/azure/quantum/target/target_factory.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from typing import Any, Dict, List, TYPE_CHECKING, Union, Type
77
from azure.quantum.target import *
8+
from azure.quantum._constants import WorkspaceKind
89

910
if TYPE_CHECKING:
1011
from azure.quantum import Workspace
@@ -134,10 +135,16 @@ def get_targets(
134135
return result
135136

136137
else:
137-
# Don't return redundant targets
138-
return [
139-
self.from_target_status(_provider_id, status, **kwargs)
140-
for _provider_id, status in target_statuses
141-
if _provider_id.lower() in self._default_targets
142-
or status.id in self._all_targets
143-
]
138+
if self._workspace._connection_params.workspace_kind == WorkspaceKind.V1:
139+
# Filter only relevant targets for user's selected framework like Cirq, Qiskit, etc.
140+
return [
141+
self.from_target_status(_provider_id, status, **kwargs)
142+
for _provider_id, status in target_statuses
143+
if _provider_id.lower() in self._default_targets
144+
or status.id in self._all_targets
145+
]
146+
else:
147+
return [
148+
self.from_target_status(_provider_id, status, **kwargs)
149+
for _provider_id, status in target_statuses
150+
]

azure-quantum/azure/quantum/workspace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class Workspace:
117117
# Internal parameter names
118118
_FROM_CONNECTION_STRING_PARAM = '_from_connection_string'
119119
_QUANTUM_ENDPOINT_PARAM = '_quantum_endpoint'
120+
_WORKSPACE_KIND_PARAM = '_workspace_kind'
120121
_MGMT_CLIENT_PARAM = '_mgmt_client'
121122

122123
def __init__(
@@ -136,6 +137,7 @@ def __init__(
136137
from_connection_string = kwargs.pop(Workspace._FROM_CONNECTION_STRING_PARAM, False)
137138
# In case from connection string, quantum_endpoint must be passed
138139
quantum_endpoint = kwargs.pop(Workspace._QUANTUM_ENDPOINT_PARAM, None)
140+
workspace_kind = kwargs.pop(Workspace._WORKSPACE_KIND_PARAM, None)
139141
# Params to pass a mock in tests
140142
self._mgmt_client = kwargs.pop(Workspace._MGMT_CLIENT_PARAM, None)
141143

@@ -148,6 +150,7 @@ def __init__(
148150
resource_id=resource_id,
149151
quantum_endpoint=quantum_endpoint,
150152
user_agent=user_agent,
153+
workspace_kind=workspace_kind,
151154
**kwargs
152155
).default_from_env_vars()
153156

@@ -320,6 +323,7 @@ def from_connection_string(cls, connection_string: str, **kwargs) -> Workspace:
320323
connection_params = WorkspaceConnectionParams(connection_string=connection_string)
321324
kwargs[cls._FROM_CONNECTION_STRING_PARAM] = True
322325
kwargs[cls._QUANTUM_ENDPOINT_PARAM] = connection_params.quantum_endpoint
326+
kwargs[cls._WORKSPACE_KIND_PARAM] = connection_params.workspace_kind.value if connection_params.workspace_kind else None
323327
return cls(
324328
subscription_id=connection_params.subscription_id,
325329
resource_group=connection_params.resource_group,

0 commit comments

Comments
 (0)