1+ from datetime import timedelta
12from typing import List , Union
23
34import pytest
45import yaml
56from msrest import Serializer
6-
7- from azure .ai .ml ._restclient .v2023_04_01_preview .models import DataFactory
87from test_utilities .utils import verify_entity_load_and_dump
98
109from azure .ai .ml import load_compute
10+ from azure .ai .ml ._restclient .v2023_04_01_preview .models import DataFactory
1111from azure .ai .ml ._restclient .v2023_08_01_preview .models import ComputeResource , ImageMetadata
1212from azure .ai .ml .constants ._compute import CustomApplicationDefaults
1313from azure .ai .ml .entities import (
1717 KubernetesCompute ,
1818 ManagedIdentityConfiguration ,
1919 SynapseSparkCompute ,
20- VirtualMachineCompute ,
2120 UnsupportedCompute ,
21+ VirtualMachineCompute ,
2222)
2323
2424
@@ -66,6 +66,9 @@ def test_compute_from_yaml(self):
6666 )[0 ]
6767 assert compute .ssh_settings .admin_username == "azureuser"
6868 assert compute .identity .type == "user_assigned"
69+ assert compute .idle_time_before_scale_down == 100
70+ assert compute .min_instances == 0
71+ assert compute .max_instances == 2
6972
7073 rest_intermediate = compute ._to_rest_object ()
7174 assert rest_intermediate .properties .compute_type == "AmlCompute"
@@ -76,7 +79,9 @@ def test_compute_from_yaml(self):
7679 assert rest_intermediate .tags is not None
7780 assert rest_intermediate .tags ["test" ] == "true"
7881 assert rest_intermediate .properties .disable_local_auth is False
79- assert rest_intermediate .properties .properties .remote_login_port_public_access == "Enabled"
82+ assert rest_intermediate .properties .properties .scale_settings .max_node_count == 2
83+ assert rest_intermediate .properties .properties .scale_settings .min_node_count == 0
84+ assert rest_intermediate .properties .properties .scale_settings .node_idle_time_before_scale_down == "PT1M40S"
8085
8186 serializer = Serializer ({"ComputeResource" : ComputeResource })
8287 body = serializer .body (rest_intermediate , "ComputeResource" )
@@ -101,6 +106,9 @@ def test_aml_compute_from_yaml_with_disable_public_access(self):
101106 assert rest_intermediate .properties .disable_local_auth is True
102107 assert rest_intermediate .location == compute .location
103108 assert rest_intermediate .properties .properties .remote_login_port_public_access == "NotSpecified"
109+ assert rest_intermediate .properties .properties .scale_settings .max_node_count == 4
110+ assert rest_intermediate .properties .properties .scale_settings .min_node_count == 0
111+ assert rest_intermediate .properties .properties .scale_settings .node_idle_time_before_scale_down == "PT2M"
104112
105113 def test_aml_compute_from_yaml_with_creds_and_disable_public_access (self ):
106114 compute : AmlCompute = load_compute ("tests/test_configs/compute/compute-aml-no-identity.yaml" )
@@ -345,6 +353,11 @@ def validate_no_public_ip(compute: Compute):
345353 assert compute .enable_node_public_ip == False
346354 compute_resource = compute ._to_rest_object ()
347355 assert compute_resource .properties .properties .enable_node_public_ip == False
356+ # AmlCompute _from_rest_object expects a timedelta object for node_idle_time_before_scale_down
357+ if compute_resource .properties .compute_type == "AmlCompute" :
358+ compute_resource .properties .properties .scale_settings .node_idle_time_before_scale_down = timedelta (
359+ seconds = 120
360+ )
348361 compute_from_rest = Compute ._from_rest_object (compute_resource )
349362 assert compute_from_rest .enable_node_public_ip == False
350363
0 commit comments