|
1 | 1 | from typing import Optional |
2 | 2 |
|
3 | 3 | import pytest |
| 4 | +import json |
4 | 5 | from marshmallow.exceptions import ValidationError |
5 | 6 |
|
6 | 7 | from azure.ai.ml import load_workspace |
7 | | -from azure.ai.ml._restclient.v2024_10_01_preview.models import Workspace |
| 8 | +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( |
| 9 | + Workspace as RestWorkspace, |
| 10 | +) |
8 | 11 | from azure.ai.ml.constants._workspace import FirewallSku, IsolationMode |
9 | 12 | from azure.ai.ml.entities import ServerlessComputeSettings, Workspace |
10 | 13 |
|
@@ -38,6 +41,95 @@ def test_serverless_compute_settings_loaded_from_rest_object( |
38 | 41 | else: |
39 | 42 | assert ServerlessComputeSettings._from_rest_object(rest_object.serverless_compute_settings) == settings |
40 | 43 |
|
| 44 | + def test_from_rest_object(self) -> None: |
| 45 | + with open("./tests/test_configs/workspace/workspace_full_rest_response.json", "r") as f: |
| 46 | + rest_object = RestWorkspace.deserialize(json.load(f)) |
| 47 | + |
| 48 | + workspace = Workspace._from_rest_object(rest_object) |
| 49 | + |
| 50 | + assert ( |
| 51 | + workspace.id |
| 52 | + == "/subscriptions/sub-id/test_workspace/providers/Microsoft.Storage/storageAccounts/storage-account-name" |
| 53 | + ) |
| 54 | + assert workspace.name == "test_workspace" |
| 55 | + assert workspace.location == "test_location" |
| 56 | + assert workspace.description == "test_description" |
| 57 | + assert workspace.tags == {"test_tag": "test_value"} |
| 58 | + assert workspace.display_name == "test_friendly_name" |
| 59 | + assert workspace.discovery_url == "test_discovery_url" |
| 60 | + assert workspace.resource_group == "providers" |
| 61 | + assert workspace.storage_account == "test_storage_account" |
| 62 | + assert workspace.key_vault == "test_key_vault" |
| 63 | + assert workspace.application_insights == "test_application_insights" |
| 64 | + assert workspace.container_registry == "test_container_registry" |
| 65 | + assert workspace.customer_managed_key.key_uri == "key_identifier" |
| 66 | + assert workspace.customer_managed_key.key_vault == "key_vault_arm_id" |
| 67 | + assert workspace.hbi_workspace is True |
| 68 | + assert workspace.public_network_access == "Enabled" |
| 69 | + assert workspace.image_build_compute == "test_image_build_compute" |
| 70 | + assert workspace.discovery_url == "test_discovery_url" |
| 71 | + assert workspace.mlflow_tracking_uri == "ml_flow_tracking_uri" |
| 72 | + assert workspace.primary_user_assigned_identity == "test_primary_user_assigned_identity" |
| 73 | + assert workspace.system_datastores_auth_mode == "AccessKey" |
| 74 | + assert workspace.enable_data_isolation == True |
| 75 | + assert workspace.allow_roleassignment_on_rg == True |
| 76 | + assert workspace._hub_id == "hub_resource_id" |
| 77 | + assert workspace._kind == "project" |
| 78 | + assert workspace._workspace_id == "workspace_id" |
| 79 | + assert workspace.identity is not None |
| 80 | + assert workspace.managed_network is not None |
| 81 | + assert workspace._feature_store_settings is not None |
| 82 | + assert workspace.network_acls is not None |
| 83 | + assert workspace.provision_network_now == True |
| 84 | + assert workspace.serverless_compute is not None |
| 85 | + assert workspace.network_acls is not None |
| 86 | + |
| 87 | + def test_from_rest_object_for_attributes_none(self) -> None: |
| 88 | + with open("./tests/test_configs/workspace/workspace_full_rest_response.json", "r") as f: |
| 89 | + rest_json = json.load(f) |
| 90 | + del rest_json["properties"]["managedNetwork"] |
| 91 | + del rest_json["properties"]["encryption"] |
| 92 | + rest_json["id"] = "/subscriptions/sub-id" |
| 93 | + del rest_json["identity"] |
| 94 | + del rest_json["properties"]["featureStoreSettings"] |
| 95 | + del rest_json["properties"]["serverlessComputeSettings"] |
| 96 | + del rest_json["properties"]["networkAcls"] |
| 97 | + rest_object = RestWorkspace.deserialize(rest_json) |
| 98 | + |
| 99 | + workspace = Workspace._from_rest_object(rest_object) |
| 100 | + |
| 101 | + assert workspace.id == "/subscriptions/sub-id" |
| 102 | + assert workspace.name == "test_workspace" |
| 103 | + assert workspace.location == "test_location" |
| 104 | + assert workspace.description == "test_description" |
| 105 | + assert workspace.tags == {"test_tag": "test_value"} |
| 106 | + assert workspace.display_name == "test_friendly_name" |
| 107 | + assert workspace.discovery_url == "test_discovery_url" |
| 108 | + assert workspace.resource_group is None |
| 109 | + assert workspace.storage_account == "test_storage_account" |
| 110 | + assert workspace.key_vault == "test_key_vault" |
| 111 | + assert workspace.application_insights == "test_application_insights" |
| 112 | + assert workspace.container_registry == "test_container_registry" |
| 113 | + assert workspace.customer_managed_key is None |
| 114 | + assert workspace.hbi_workspace is True |
| 115 | + assert workspace.public_network_access == "Enabled" |
| 116 | + assert workspace.image_build_compute == "test_image_build_compute" |
| 117 | + assert workspace.discovery_url == "test_discovery_url" |
| 118 | + assert workspace.mlflow_tracking_uri == "ml_flow_tracking_uri" |
| 119 | + assert workspace.primary_user_assigned_identity == "test_primary_user_assigned_identity" |
| 120 | + assert workspace.system_datastores_auth_mode == "AccessKey" |
| 121 | + assert workspace.enable_data_isolation == True |
| 122 | + assert workspace.allow_roleassignment_on_rg == True |
| 123 | + assert workspace._hub_id == "hub_resource_id" |
| 124 | + assert workspace._kind == "project" |
| 125 | + assert workspace._workspace_id == "workspace_id" |
| 126 | + assert workspace.identity is None |
| 127 | + assert workspace.managed_network is None |
| 128 | + assert workspace._feature_store_settings is None |
| 129 | + assert workspace.network_acls is None |
| 130 | + assert workspace.provision_network_now == True |
| 131 | + assert workspace.serverless_compute is None |
| 132 | + |
41 | 133 | def test_serverless_compute_settings_subnet_name_must_be_an_arm_id(self) -> None: |
42 | 134 | with pytest.raises(ValidationError): |
43 | 135 | ServerlessComputeSettings(custom_subnet="justaname", no_public_ip=True) |
@@ -65,11 +157,17 @@ def test_serverless_compute_settings_subnet_name_must_be_an_arm_id(self) -> None |
65 | 157 | ) |
66 | 158 | def test_workspace_load_override_serverless(self, settings: ServerlessComputeSettings) -> None: |
67 | 159 | params_override = [ |
68 | | - {"serverless_compute": {"custom_subnet": settings.custom_subnet, "no_public_ip": settings.no_public_ip}} |
| 160 | + { |
| 161 | + "serverless_compute": { |
| 162 | + "custom_subnet": settings.custom_subnet, |
| 163 | + "no_public_ip": settings.no_public_ip, |
| 164 | + } |
| 165 | + } |
69 | 166 | ] |
70 | 167 |
|
71 | 168 | workspace_override = load_workspace( |
72 | | - "./tests/test_configs/workspace/workspace_serverless.yaml", params_override=params_override |
| 169 | + "./tests/test_configs/workspace/workspace_serverless.yaml", |
| 170 | + params_override=params_override, |
73 | 171 | ) |
74 | 172 | assert workspace_override.serverless_compute == settings |
75 | 173 |
|
|
0 commit comments