Skip to content

Commit 6c0728a

Browse files
authored
add disableLocalAuth while creating aml compute (Azure#38913)
* add disableLocalAuth while creating aml compute * add unit tests * push recordings * optimize expression * format * fix black error format * update logic for disable local auth * formatting * add info to changelog
1 parent ce16093 commit 6c0728a

File tree

6 files changed

+82
-14
lines changed

6 files changed

+82
-14
lines changed

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
### Features Added
55

66
### Bugs Fixed
7-
7+
- Fixed disableLocalAuthentication handling while creating amlCompute
88

99
## 1.23.0 (2024-12-05)
1010

sdk/ml/azure-ai-ml/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/ml/azure-ai-ml",
5-
"Tag": "python/ml/azure-ai-ml_d220df7fea"
5+
"Tag": "python/ml/azure-ai-ml_003b900b39"
66
}

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_compute/aml_compute.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
from typing import Any, Dict, Optional
88

9-
from azure.ai.ml._restclient.v2022_10_01_preview.models import AmlCompute as AmlComputeRest
10-
from azure.ai.ml._restclient.v2022_10_01_preview.models import (
9+
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
10+
AmlCompute as AmlComputeRest,
11+
)
12+
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
1113
AmlComputeProperties,
1214
ComputeResource,
1315
ResourceId,
@@ -16,7 +18,11 @@
1618
)
1719
from azure.ai.ml._schema._utils.utils import get_subnet_str
1820
from azure.ai.ml._schema.compute.aml_compute import AmlComputeSchema
19-
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal, to_iso_duration_format
21+
from azure.ai.ml._utils.utils import (
22+
camel_to_snake,
23+
snake_to_pascal,
24+
to_iso_duration_format,
25+
)
2026
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
2127
from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType
2228
from azure.ai.ml.entities._credentials import IdentityConfiguration
@@ -180,7 +186,7 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
180186
name=rest_obj.name,
181187
id=rest_obj.id,
182188
description=prop.description,
183-
location=prop.compute_location if prop.compute_location else rest_obj.location,
189+
location=(prop.compute_location if prop.compute_location else rest_obj.location),
184190
tags=rest_obj.tags if rest_obj.tags else None,
185191
provisioning_state=prop.provisioning_state,
186192
provisioning_errors=(
@@ -190,8 +196,8 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
190196
),
191197
size=prop.properties.vm_size,
192198
tier=camel_to_snake(prop.properties.vm_priority),
193-
min_instances=prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None,
194-
max_instances=prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None,
199+
min_instances=(prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None),
200+
max_instances=(prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None),
195201
network_settings=network_settings or None,
196202
ssh_settings=ssh_settings,
197203
ssh_public_access_enabled=(prop.properties.remote_login_port_public_access == "Enabled"),
@@ -200,7 +206,9 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
200206
if prop.properties.scale_settings and prop.properties.scale_settings.node_idle_time_before_scale_down
201207
else None
202208
),
203-
identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None,
209+
identity=(
210+
IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None
211+
),
204212
created_on=prop.additional_properties.get("createdOn", None),
205213
enable_node_public_ip=(
206214
prop.properties.enable_node_public_ip if prop.properties.enable_node_public_ip is not None else True
@@ -244,21 +252,28 @@ def _to_rest_object(self) -> ComputeResource:
244252
),
245253
)
246254
remote_login_public_access = "Enabled"
255+
disableLocalAuth = not (self.ssh_public_access_enabled and self.ssh_settings is not None)
247256
if self.ssh_public_access_enabled is not None:
248257
remote_login_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled"
258+
249259
else:
250260
remote_login_public_access = "NotSpecified"
251261
aml_prop = AmlComputeProperties(
252262
vm_size=self.size if self.size else ComputeDefaults.VMSIZE,
253263
vm_priority=snake_to_pascal(self.tier),
254-
user_account_credentials=self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None,
264+
user_account_credentials=(self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None),
255265
scale_settings=scale_settings,
256266
subnet=subnet_resource,
257267
remote_login_port_public_access=remote_login_public_access,
258268
enable_node_public_ip=self.enable_node_public_ip,
259269
)
260270

261-
aml_comp = AmlComputeRest(description=self.description, compute_type=self.type, properties=aml_prop)
271+
aml_comp = AmlComputeRest(
272+
description=self.description,
273+
compute_type=self.type,
274+
properties=aml_prop,
275+
disable_local_auth=disableLocalAuth,
276+
)
262277
return ComputeResource(
263278
location=self.location,
264279
properties=aml_comp,

sdk/ml/azure-ai-ml/tests/compute/unittests/test_compute_entity.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,24 @@ def test_compute_from_rest(self):
5454

5555
def _test_loaded_compute(self, compute: AmlCompute):
5656
assert compute.name == "banchaml"
57-
assert compute.ssh_settings.admin_username == "azureuser"
58-
assert compute.identity.type == "user_assigned"
57+
assert compute.type == "amlcompute"
58+
assert compute.location == "eastus"
59+
assert compute.description == "some_desc_aml"
5960

6061
def test_compute_from_yaml(self):
6162
compute: AmlCompute = verify_entity_load_and_dump(
6263
load_compute,
6364
self._test_loaded_compute,
6465
"tests/test_configs/compute/compute-aml.yaml",
6566
)[0]
66-
assert compute.location == "eastus"
67+
assert compute.ssh_settings.admin_username == "azureuser"
68+
assert compute.identity.type == "user_assigned"
6769

6870
rest_intermediate = compute._to_rest_object()
6971
assert rest_intermediate.properties.compute_type == "AmlCompute"
7072
assert rest_intermediate.properties.properties.user_account_credentials.admin_user_name == "azureuser"
7173
assert rest_intermediate.properties.properties.enable_node_public_ip
74+
assert rest_intermediate.properties.disable_local_auth is False
7275
assert rest_intermediate.location == compute.location
7376
assert rest_intermediate.tags is not None
7477
assert rest_intermediate.tags["test"] == "true"
@@ -81,6 +84,36 @@ def test_compute_from_yaml(self):
8184
)
8285
assert body["location"] == compute.location
8386

87+
def test_aml_compute_from_yaml_with_disable_public_access(self):
88+
89+
compute: AmlCompute = verify_entity_load_and_dump(
90+
load_compute,
91+
self._test_loaded_compute,
92+
"tests/test_configs/compute/compute-aml-disable-public-access.yaml",
93+
)[0]
94+
95+
rest_intermediate = compute._to_rest_object()
96+
97+
assert rest_intermediate.properties.compute_type == "AmlCompute"
98+
assert rest_intermediate.properties.properties.enable_node_public_ip
99+
assert rest_intermediate.properties.disable_local_auth is True
100+
assert rest_intermediate.location == compute.location
101+
102+
def test_aml_compute_from_yaml_with_disable_public_access_when_no_sshSettings(self):
103+
104+
compute: AmlCompute = verify_entity_load_and_dump(
105+
load_compute,
106+
self._test_loaded_compute,
107+
"tests/test_configs/compute/compute-aml-public-access-no-ssh.yaml",
108+
)[0]
109+
110+
rest_intermediate = compute._to_rest_object()
111+
112+
assert rest_intermediate.properties.compute_type == "AmlCompute"
113+
assert rest_intermediate.properties.properties.enable_node_public_ip
114+
assert rest_intermediate.properties.disable_local_auth is True
115+
assert rest_intermediate.location == compute.location
116+
84117
def test_compute_vm_from_yaml(self):
85118
resource_id = "/subscriptions/13e50845-67bc-4ac5-94db-48d493a6d9e8/resourceGroups/myrg/providers/Microsoft.Compute/virtualMachines/myvm"
86119
fake_key = "myfakekey"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
name: banchaml
2+
type: amlcompute
3+
tier: dedicated
4+
description: some_desc_aml
5+
size: Standard_DS2_v2
6+
min_instances: 0
7+
max_instances: 2
8+
location: eastus
9+
idle_time_before_scale_down: 120
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
name: banchaml
2+
type: amlcompute
3+
description: some_desc_aml
4+
size: Standard_DS2_v2
5+
location: eastus
6+
tags:
7+
test: "true"
8+
ssh_public_access_enabled: true
9+
max_instances: 2
10+
idle_time_before_scale_down: 100
11+
enable_node_public_ip: true

0 commit comments

Comments
 (0)