66
77from typing import Any , Dict , Optional
88
9- from azure .ai .ml ._restclient .v2022_10_01_preview .models import AmlCompute as AmlComputeRest
10- from azure .ai .ml ._restclient .v2022_10_01_preview .models import (
9+ from azure .ai .ml ._restclient .v2022_12_01_preview .models import (
10+ AmlCompute as AmlComputeRest ,
11+ )
12+ from azure .ai .ml ._restclient .v2022_12_01_preview .models import (
1113 AmlComputeProperties ,
1214 ComputeResource ,
1315 ResourceId ,
1618)
1719from azure .ai .ml ._schema ._utils .utils import get_subnet_str
1820from azure .ai .ml ._schema .compute .aml_compute import AmlComputeSchema
19- from azure .ai .ml ._utils .utils import camel_to_snake , snake_to_pascal , to_iso_duration_format
21+ from azure .ai .ml ._utils .utils import (
22+ camel_to_snake ,
23+ snake_to_pascal ,
24+ to_iso_duration_format ,
25+ )
2026from azure .ai .ml .constants ._common import BASE_PATH_CONTEXT_KEY , TYPE
2127from azure .ai .ml .constants ._compute import ComputeDefaults , ComputeType
2228from azure .ai .ml .entities ._credentials import IdentityConfiguration
@@ -180,7 +186,7 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
180186 name = rest_obj .name ,
181187 id = rest_obj .id ,
182188 description = prop .description ,
183- location = prop .compute_location if prop .compute_location else rest_obj .location ,
189+ location = ( prop .compute_location if prop .compute_location else rest_obj .location ) ,
184190 tags = rest_obj .tags if rest_obj .tags else None ,
185191 provisioning_state = prop .provisioning_state ,
186192 provisioning_errors = (
@@ -190,8 +196,8 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
190196 ),
191197 size = prop .properties .vm_size ,
192198 tier = camel_to_snake (prop .properties .vm_priority ),
193- min_instances = prop .properties .scale_settings .min_node_count if prop .properties .scale_settings else None ,
194- max_instances = prop .properties .scale_settings .max_node_count if prop .properties .scale_settings else None ,
199+ min_instances = ( prop .properties .scale_settings .min_node_count if prop .properties .scale_settings else None ) ,
200+ max_instances = ( prop .properties .scale_settings .max_node_count if prop .properties .scale_settings else None ) ,
195201 network_settings = network_settings or None ,
196202 ssh_settings = ssh_settings ,
197203 ssh_public_access_enabled = (prop .properties .remote_login_port_public_access == "Enabled" ),
@@ -200,7 +206,9 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
200206 if prop .properties .scale_settings and prop .properties .scale_settings .node_idle_time_before_scale_down
201207 else None
202208 ),
203- identity = IdentityConfiguration ._from_compute_rest_object (rest_obj .identity ) if rest_obj .identity else None ,
209+ identity = (
210+ IdentityConfiguration ._from_compute_rest_object (rest_obj .identity ) if rest_obj .identity else None
211+ ),
204212 created_on = prop .additional_properties .get ("createdOn" , None ),
205213 enable_node_public_ip = (
206214 prop .properties .enable_node_public_ip if prop .properties .enable_node_public_ip is not None else True
@@ -244,21 +252,28 @@ def _to_rest_object(self) -> ComputeResource:
244252 ),
245253 )
246254 remote_login_public_access = "Enabled"
255+ disableLocalAuth = not (self .ssh_public_access_enabled and self .ssh_settings is not None )
247256 if self .ssh_public_access_enabled is not None :
248257 remote_login_public_access = "Enabled" if self .ssh_public_access_enabled else "Disabled"
258+
249259 else :
250260 remote_login_public_access = "NotSpecified"
251261 aml_prop = AmlComputeProperties (
252262 vm_size = self .size if self .size else ComputeDefaults .VMSIZE ,
253263 vm_priority = snake_to_pascal (self .tier ),
254- user_account_credentials = self .ssh_settings ._to_user_account_credentials () if self .ssh_settings else None ,
264+ user_account_credentials = ( self .ssh_settings ._to_user_account_credentials () if self .ssh_settings else None ) ,
255265 scale_settings = scale_settings ,
256266 subnet = subnet_resource ,
257267 remote_login_port_public_access = remote_login_public_access ,
258268 enable_node_public_ip = self .enable_node_public_ip ,
259269 )
260270
261- aml_comp = AmlComputeRest (description = self .description , compute_type = self .type , properties = aml_prop )
271+ aml_comp = AmlComputeRest (
272+ description = self .description ,
273+ compute_type = self .type ,
274+ properties = aml_prop ,
275+ disable_local_auth = disableLocalAuth ,
276+ )
262277 return ComputeResource (
263278 location = self .location ,
264279 properties = aml_comp ,
0 commit comments