Skip to content

Commit 7ca183a

Browse files
joshharrinJosh Harrington
andauthored
Address comments from Azure SDK team (Azure#29379)
* address comments * fix pylint and tests * fix pylint * fix pylint * fix OutboundRule classes * fix schema dump for managed network * black, cspell, pylint fixes --------- Co-authored-by: Josh Harrington <[email protected]>
1 parent 844e0c5 commit 7ca183a

11 files changed

+5542
-891
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/networking.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from azure.ai.ml._schema.core.fields import NestedField
1212
from azure.ai.ml.entities._workspace.networking import (
1313
ManagedNetwork,
14-
OutboundRule,
1514
FqdnDestination,
1615
ServiceTagDestination,
1716
PrivateEndpointDestination,
@@ -66,20 +65,23 @@ def createdestobject(self, data, **kwargs):
6665
category = data.get("category", OutboundRuleCategory.USER_DEFINED)
6766
if dest:
6867
if isinstance(dest, str):
69-
return FqdnDestination(dest, _snake_to_camel(category))
68+
return FqdnDestination(rule_name=None, destination=dest, category=_snake_to_camel(category))
7069
else:
7170
if dest.get("subresource_target", False):
7271
return PrivateEndpointDestination(
73-
dest["service_resource_id"],
74-
dest["subresource_target"],
75-
dest["spark_enabled"],
76-
_snake_to_camel(category),
72+
rule_name=None,
73+
service_resource_id=dest["service_resource_id"],
74+
subresource_target=dest["subresource_target"],
75+
spark_enabled=dest["spark_enabled"],
76+
category=_snake_to_camel(category),
7777
)
78-
if dest.get("service_tag", False):
79-
return ServiceTagDestination(
80-
dest["service_tag"], dest["protocol"], dest["port_ranges"], _snake_to_camel(category)
81-
)
82-
return OutboundRule(data)
78+
return ServiceTagDestination(
79+
rule_name=None,
80+
service_tag=dest["service_tag"],
81+
protocol=dest["protocol"],
82+
port_ranges=dest["port_ranges"],
83+
category=_snake_to_camel(category),
84+
)
8385

8486
def fqdn_dest2dict(self, fqdndest):
8587
res = fqdndest
@@ -118,7 +120,18 @@ class ManagedNetworkSchema(metaclass=PatchedSchemaMeta):
118120

119121
@post_load
120122
def make(self, data, **kwargs):
121-
if data.get("outbound_rules", False):
122-
return ManagedNetwork(_snake_to_camel(data["isolation_mode"]), data["outbound_rules"])
123+
rules_dict = data.get("outbound_rules", False)
124+
if rules_dict:
125+
rules_as_list = []
126+
for rule_name in rules_dict:
127+
rule = rules_dict[rule_name]
128+
rule.rule_name = rule_name
129+
rules_as_list.append(rule)
130+
return ManagedNetwork(_snake_to_camel(data["isolation_mode"]), rules_as_list)
123131
else:
124132
return ManagedNetwork(_snake_to_camel(data["isolation_mode"]))
133+
134+
@pre_dump
135+
def predump(self, data, **kwargs):
136+
data.outbound_rules = {rule.rule_name: rule for rule in data.outbound_rules}
137+
return data

sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
FqdnDestination,
134134
ServiceTagDestination,
135135
PrivateEndpointDestination,
136+
IsolationMode,
136137
)
137138
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
138139
from ._workspace.workspace import Workspace
@@ -222,6 +223,7 @@
222223
"FqdnDestination",
223224
"ServiceTagDestination",
224225
"PrivateEndpointDestination",
226+
"IsolationMode",
225227
"EndpointConnection",
226228
"CustomerManagedKey",
227229
"DataImport",

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/networking.py

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
from typing import Any, Dict, Optional
5+
from typing import Any, Dict, Optional, List
66

77
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
88
ManagedNetworkSettings as RestManagedNetwork,
@@ -19,23 +19,36 @@
1919

2020
@experimental
2121
class OutboundRule:
22+
"""Base class for Outbound Rules, should not be instantiated directly.
23+
24+
:param rule_name: Name of the outbound rule.
25+
:type rule_name: str
26+
:param type: Type of the outbound rule. Supported types are "FQDN", "PrivateEndpoint", "ServiceTag"
27+
:type type: str
28+
"""
29+
2230
def __init__(
23-
self, type: str = None, category: str = OutboundRuleCategory.USER_DEFINED # pylint: disable=redefined-builtin
31+
self,
32+
*,
33+
rule_name: str = None,
34+
**kwargs,
2435
) -> None:
25-
self.type = type
26-
self.category = category
36+
self.rule_name = rule_name
37+
self.type = kwargs.pop("type", None)
38+
self.category = kwargs.pop("category", OutboundRuleCategory.USER_DEFINED)
2739

2840
@classmethod
29-
def _from_rest_object(cls, rest_obj: Any) -> "OutboundRule":
41+
def _from_rest_object(cls, rest_obj: Any, rule_name: str) -> "OutboundRule":
3042
if isinstance(rest_obj, RestFqdnOutboundRule):
31-
rule = FqdnDestination(destination=rest_obj.destination)
43+
rule = FqdnDestination(destination=rest_obj.destination, rule_name=rule_name)
3244
rule.category = rest_obj.category
3345
return rule
3446
if isinstance(rest_obj, RestPrivateEndpointOutboundRule):
3547
rule = PrivateEndpointDestination(
3648
service_resource_id=rest_obj.destination.service_resource_id,
3749
subresource_target=rest_obj.destination.subresource_target,
3850
spark_enabled=rest_obj.destination.spark_enabled,
51+
rule_name=rule_name,
3952
)
4053
rule.category = rest_obj.category
4154
return rule
@@ -44,37 +57,44 @@ def _from_rest_object(cls, rest_obj: Any) -> "OutboundRule":
4457
service_tag=rest_obj.destination.service_tag,
4558
protocol=rest_obj.destination.protocol,
4659
port_ranges=rest_obj.destination.port_ranges,
60+
rule_name=rule_name,
4761
)
4862
rule.category = rest_obj.category
4963
return rule
5064

5165

5266
@experimental
5367
class FqdnDestination(OutboundRule):
54-
def __init__(self, destination: str, category: str = OutboundRuleCategory.USER_DEFINED) -> None:
68+
def __init__(self, *, rule_name: str, destination: str, **kwargs) -> None:
5569
self.destination = destination
56-
OutboundRule.__init__(self, type=OutboundRuleType.FQDN, category=category)
70+
category = kwargs.pop("category", OutboundRuleCategory.USER_DEFINED)
71+
OutboundRule.__init__(self, type=OutboundRuleType.FQDN, category=category, rule_name=rule_name)
5772

5873
def _to_rest_object(self) -> RestFqdnOutboundRule:
5974
return RestFqdnOutboundRule(type=self.type, category=self.category, destination=self.destination)
6075

6176
def _to_dict(self) -> Dict:
62-
return {"type": OutboundRuleType.FQDN, "category": self.category, "destination": self.destination}
77+
return {
78+
self.rule_name: {"type": OutboundRuleType.FQDN, "category": self.category, "destination": self.destination}
79+
}
6380

6481

6582
@experimental
6683
class PrivateEndpointDestination(OutboundRule):
6784
def __init__(
6885
self,
86+
*,
87+
rule_name: str,
6988
service_resource_id: str,
7089
subresource_target: str,
7190
spark_enabled: bool = False,
72-
category: str = OutboundRuleCategory.USER_DEFINED,
91+
**kwargs,
7392
) -> None:
7493
self.service_resource_id = service_resource_id
7594
self.subresource_target = subresource_target
7695
self.spark_enabled = spark_enabled
77-
OutboundRule.__init__(self, OutboundRuleType.PRIVATE_ENDPOINT, category=category)
96+
category = kwargs.pop("category", OutboundRuleCategory.USER_DEFINED)
97+
OutboundRule.__init__(self, type=OutboundRuleType.PRIVATE_ENDPOINT, category=category, rule_name=rule_name)
7898

7999
def _to_rest_object(self) -> RestPrivateEndpointOutboundRule:
80100
return RestPrivateEndpointOutboundRule(
@@ -89,25 +109,34 @@ def _to_rest_object(self) -> RestPrivateEndpointOutboundRule:
89109

90110
def _to_dict(self) -> Dict:
91111
return {
92-
"type": OutboundRuleType.PRIVATE_ENDPOINT,
93-
"category": self.category,
94-
"destination": {
95-
"service_resource_id": self.service_resource_id,
96-
"subresource_target": self.subresource_target,
97-
"spark_enabled": self.spark_enabled,
98-
},
112+
self.rule_name: {
113+
"type": OutboundRuleType.PRIVATE_ENDPOINT,
114+
"category": self.category,
115+
"destination": {
116+
"service_resource_id": self.service_resource_id,
117+
"subresource_target": self.subresource_target,
118+
"spark_enabled": self.spark_enabled,
119+
},
120+
}
99121
}
100122

101123

102124
@experimental
103125
class ServiceTagDestination(OutboundRule):
104126
def __init__(
105-
self, service_tag: str, protocol: str, port_ranges: str, category: str = OutboundRuleCategory.USER_DEFINED
127+
self,
128+
*,
129+
rule_name: str,
130+
service_tag: str,
131+
protocol: str,
132+
port_ranges: str,
133+
**kwargs,
106134
) -> None:
107135
self.service_tag = service_tag
108136
self.protocol = protocol
109137
self.port_ranges = port_ranges
110-
OutboundRule.__init__(self, OutboundRuleType.SERVICE_TAG, category=category)
138+
category = kwargs.pop("category", OutboundRuleCategory.USER_DEFINED)
139+
OutboundRule.__init__(self, type=OutboundRuleType.SERVICE_TAG, category=category, rule_name=rule_name)
111140

112141
def _to_rest_object(self) -> RestServiceTagOutboundRule:
113142
return RestServiceTagOutboundRule(
@@ -120,13 +149,15 @@ def _to_rest_object(self) -> RestServiceTagOutboundRule:
120149

121150
def _to_dict(self) -> Dict:
122151
return {
123-
"type": OutboundRuleType.SERVICE_TAG,
124-
"category": self.category,
125-
"destination": {
126-
"service_tag": self.service_tag,
127-
"protocol": self.protocol,
128-
"port_ranges": self.port_ranges,
129-
},
152+
self.rule_name: {
153+
"type": OutboundRuleType.SERVICE_TAG,
154+
"category": self.category,
155+
"destination": {
156+
"service_tag": self.service_tag,
157+
"protocol": self.protocol,
158+
"port_ranges": self.port_ranges,
159+
},
160+
}
130161
}
131162

132163

@@ -135,7 +166,7 @@ class ManagedNetwork:
135166
def __init__(
136167
self,
137168
isolation_mode: str = IsolationMode.DISABLED,
138-
outbound_rules: Optional[Dict[str, OutboundRule]] = None,
169+
outbound_rules: Optional[List[OutboundRule]] = None,
139170
network_id: Optional[str] = None,
140171
) -> None:
141172
self.isolation_mode = isolation_mode
@@ -145,8 +176,8 @@ def __init__(
145176
def _to_rest_object(self) -> RestManagedNetwork:
146177
rest_outbound_rules = (
147178
{
148-
rule_name: self.outbound_rules[rule_name]._to_rest_object() # pylint: disable=protected-access
149-
for rule_name in self.outbound_rules
179+
outbound_rule.rule_name: outbound_rule._to_rest_object() # pylint: disable=protected-access
180+
for outbound_rule in self.outbound_rules
150181
}
151182
if self.outbound_rules
152183
else None
@@ -156,12 +187,12 @@ def _to_rest_object(self) -> RestManagedNetwork:
156187
@classmethod
157188
def _from_rest_object(cls, obj: RestManagedNetwork) -> "ManagedNetwork":
158189
from_rest_outbound_rules = (
159-
{
160-
rule_name: OutboundRule._from_rest_object( # pylint: disable=protected-access
161-
obj.outbound_rules[rule_name]
190+
[
191+
OutboundRule._from_rest_object( # pylint: disable=protected-access
192+
obj.outbound_rules[rule_name], rule_name=rule_name
162193
)
163194
for rule_name in obj.outbound_rules
164-
}
195+
]
165196
if obj.outbound_rules
166197
else {}
167198
)

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def callback(_, deserialized, args):
260260
poller = self._operation.begin_update(resource_group, workspace_name, update_param, polling=True, cls=callback)
261261
return poller
262262

263-
def begin_delete(self, name: str, *, delete_dependent_resources: bool, **kwargs: Dict) -> LROPoller:
263+
def begin_delete(self, name: str, *, delete_dependent_resources: bool, **kwargs: Dict) -> LROPoller[None]:
264264
workspace = self.get(name, **kwargs)
265265
resource_group = kwargs.get("resource_group") or self._resource_group_name
266266
if delete_dependent_resources:

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_outbound_rule_operations.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
from typing import Dict
5+
from typing import Dict, Iterable
66
from azure.ai.ml._restclient.v2022_12_01_preview import AzureMachineLearningWorkspaces as ServiceClient122022Preview
77
from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
88

@@ -36,29 +36,31 @@ def __init__(
3636
self._credentials = credentials
3737
self._init_kwargs = kwargs
3838

39-
@monitor_with_activity(logger, "WorkspaceOutboundRule.Show", ActivityType.PUBLICAPI)
40-
def show(self, resource_group: str, ws_name: str, outbound_rule_name: str, **kwargs) -> OutboundRule:
39+
@monitor_with_activity(logger, "WorkspaceOutboundRule.Get", ActivityType.PUBLICAPI)
40+
def get(self, resource_group: str, ws_name: str, outbound_rule_name: str, **kwargs) -> OutboundRule:
4141
workspace_name = self._check_workspace_name(ws_name)
4242
resource_group = kwargs.get("resource_group") or self._resource_group_name
4343

4444
obj = self._rule_operation.get(resource_group, workspace_name, outbound_rule_name)
45-
return OutboundRule._from_rest_object(obj) # pylint: disable=protected-access
45+
return OutboundRule._from_rest_object(obj, rule_name=outbound_rule_name) # pylint: disable=protected-access
4646

4747
@monitor_with_activity(logger, "WorkspaceOutboundRule.List", ActivityType.PUBLICAPI)
48-
def list(self, resource_group: str, ws_name: str, **kwargs) -> Dict[str, OutboundRule]:
48+
def list(self, resource_group: str, ws_name: str, **kwargs) -> Iterable[OutboundRule]:
4949
workspace_name = self._check_workspace_name(ws_name)
5050
resource_group = kwargs.get("resource_group") or self._resource_group_name
5151

5252
rest_rules = self._rule_operation.list(resource_group, workspace_name)
5353

54-
result = {
55-
rule_name: OutboundRule._from_rest_object(rest_rules[rule_name]) # pylint: disable=protected-access
54+
result = [
55+
OutboundRule._from_rest_object( # pylint: disable=protected-access
56+
rest_obj=rest_rules[rule_name], rule_name=rule_name
57+
)
5658
for rule_name in rest_rules.keys()
57-
}
59+
]
5860
return result
5961

6062
@monitor_with_activity(logger, "WorkspaceOutboundRule.Remove", ActivityType.PUBLICAPI)
61-
def remove(self, resource_group: str, ws_name: str, outbound_rule_name: str, **kwargs) -> LROPoller:
63+
def begin_remove(self, resource_group: str, ws_name: str, outbound_rule_name: str, **kwargs) -> LROPoller[None]:
6264
workspace_name = self._check_workspace_name(ws_name)
6365
resource_group = kwargs.get("resource_group") or self._resource_group_name
6466

0 commit comments

Comments
 (0)