Skip to content

Commit 65e5a32

Browse files
authored
fix to_dict (#35501)
1 parent 7734b1c commit 65e5a32

File tree

10 files changed

+197
-75
lines changed

10 files changed

+197
-75
lines changed

sdk/ml/azure-ai-ml/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/ml/azure-ai-ml",
5-
"Tag": "python/ml/azure-ai-ml_ce8aa03671"
5+
"Tag": "python/ml/azure-ai-ml_bcde27db64"
66
}

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
OpenAIConnectionSchema,
1818
SerpConnectionSchema,
1919
ServerlessConnectionSchema,
20+
OneLakeArtifactSchema,
2021
)
2122

2223
__all__ = [
@@ -32,4 +33,5 @@
3233
"OpenAIConnectionSchema",
3334
"SerpConnectionSchema",
3435
"ServerlessConnectionSchema",
36+
"OneLakeArtifactSchema",
3537
]

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/connection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from marshmallow import fields, post_load
88

9-
from azure.ai.ml._restclient.v2023_06_01_preview.models import ConnectionCategory
9+
from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
1010
from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
1111
from azure.ai.ml._schema.core.resource import ResourceSchema
1212
from azure.ai.ml._schema.job import CreationContextSchema
@@ -19,6 +19,8 @@
1919
UsernamePasswordConfigurationSchema,
2020
AccessKeyConfigurationSchema,
2121
ApiKeyConfigurationSchema,
22+
AadCredentialConfigurationSchema,
23+
NoneCredentialConfigurationSchema,
2224
)
2325
from azure.ai.ml._utils.utils import camel_to_snake
2426
from azure.ai.ml.constants._common import ConnectionTypes
@@ -60,6 +62,8 @@ class ConnectionSchema(ResourceSchema):
6062
NestedField(AccessKeyConfigurationSchema),
6163
NestedField(ApiKeyConfigurationSchema),
6264
NestedField(AccountKeyConfigurationSchema),
65+
NestedField(AadCredentialConfigurationSchema),
66+
NestedField(NoneCredentialConfigurationSchema),
6367
],
6468
required=False,
6569
load_default=NoneCredentialConfiguration(),
@@ -71,7 +75,8 @@ class ConnectionSchema(ResourceSchema):
7175
def make(self, data, **kwargs):
7276
from azure.ai.ml.entities import Connection
7377

74-
# Replace ALDS gen 2 empty default with AAD over None
78+
# Most non-subclassed connections default to a none credential if none
79+
# is provided. ALDS Gen 2 connections default to AAD with this code.
7580
if (
7681
data.get("type", None) == ConnectionTypes.AZURE_DATA_LAKE_GEN_2
7782
and data.get("credentials", None) == NoneCredentialConfiguration()

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/connection_subtypes.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# pylint: disable=unused-argument
66

77
from marshmallow import fields, post_load
8+
from marshmallow.exceptions import ValidationError
9+
from marshmallow.decorators import pre_load
810

911
from azure.ai.ml._restclient.v2024_04_01_preview.models import ConnectionCategory
1012
from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField
@@ -15,6 +17,7 @@
1517
SasTokenConfigurationSchema,
1618
ServicePrincipalConfigurationSchema,
1719
AccountKeyConfigurationSchema,
20+
AadCredentialConfigurationSchema,
1821
)
1922
from azure.ai.ml.entities import AadCredentialConfiguration
2023
from .connection import ConnectionSchema
@@ -30,6 +33,7 @@ class AzureBlobStoreConnectionSchema(ConnectionSchema):
3033
[
3134
NestedField(SasTokenConfigurationSchema),
3235
NestedField(AccountKeyConfigurationSchema),
36+
NestedField(AadCredentialConfigurationSchema),
3337
],
3438
required=False,
3539
load_default=AadCredentialConfiguration(),
@@ -52,13 +56,34 @@ class MicrosoftOneLakeConnectionSchema(ConnectionSchema):
5256
type = StringTransformedEnum(
5357
allowed_values=ConnectionCategory.AZURE_ONE_LAKE, casing_transform=camel_to_snake, required=True
5458
)
55-
credentials = NestedField(
56-
ServicePrincipalConfigurationSchema, required=False, load_default=AadCredentialConfiguration()
59+
credentials = UnionField(
60+
[NestedField(ServicePrincipalConfigurationSchema), NestedField(AadCredentialConfigurationSchema)],
61+
required=False,
62+
load_default=AadCredentialConfiguration(),
5763
)
58-
artifact = NestedField(OneLakeArtifactSchema, required=True)
59-
60-
endpoint = fields.Str()
61-
one_lake_workspace_name = fields.Str()
64+
artifact = NestedField(OneLakeArtifactSchema, required=False, allow_none=True)
65+
66+
endpoint = fields.Str(required=False)
67+
one_lake_workspace_name = fields.Str(required=False)
68+
69+
@pre_load
70+
def check_for_target(self, data, **kwargs):
71+
target = data.get("target", None)
72+
artifact = data.get("artifact", None)
73+
endpoint = data.get("endpoint", None)
74+
one_lake_workspace_name = data.get("one_lake_workspace_name", None)
75+
# If the user is using a target, then they don't need the artifact and one lake workspace name.
76+
# This is distinct from when the user set's the 'endpoint' value, which is also used to construct
77+
# the target. If the target is already present, then the loaded connection YAML was probably produced
78+
# by dumping an extant connection.
79+
if target is None:
80+
if artifact is None:
81+
raise ValidationError("If target is unset, then artifact must be set")
82+
if endpoint is None:
83+
raise ValidationError("If target is unset, then endpoint must be set")
84+
if one_lake_workspace_name is None:
85+
raise ValidationError("If target is unset, then one_lake_workspace_name must be set")
86+
return data
6287

6388
@post_load
6489
def make(self, data, **kwargs):

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/connections/credentials.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
AccessKeyConfiguration,
2626
ApiKeyConfiguration,
2727
AccountKeyConfiguration,
28+
AadCredentialConfiguration,
29+
NoneCredentialConfiguration,
2830
)
2931

3032

@@ -148,3 +150,29 @@ class AccountKeyConfigurationSchema(metaclass=PatchedSchemaMeta):
148150
def make(self, data: Dict[str, str], **kwargs) -> AccountKeyConfiguration:
149151
data.pop("type")
150152
return AccountKeyConfiguration(**data)
153+
154+
155+
class AadCredentialConfigurationSchema(metaclass=PatchedSchemaMeta):
156+
type = StringTransformedEnum(
157+
allowed_values=ConnectionAuthType.AAD,
158+
casing_transform=camel_to_snake,
159+
required=True,
160+
)
161+
162+
@post_load
163+
def make(self, data: Dict[str, str], **kwargs) -> AadCredentialConfiguration:
164+
data.pop("type")
165+
return AadCredentialConfiguration(**data)
166+
167+
168+
class NoneCredentialConfigurationSchema(metaclass=PatchedSchemaMeta):
169+
type = StringTransformedEnum(
170+
allowed_values=ConnectionAuthType.NONE,
171+
casing_transform=camel_to_snake,
172+
required=True,
173+
)
174+
175+
@post_load
176+
def make(self, data: Dict[str, str], **kwargs) -> NoneCredentialConfiguration:
177+
data.pop("type")
178+
return NoneCredentialConfiguration(**data)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/connection_subtypes.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ class MicrosoftOneLakeConnection(Connection):
162162
:param endpoint: The endpoint of the connection.
163163
:type endpoint: str
164164
:param artifact: The artifact class used to further specify the connection.
165-
:type artifact: ~azure.ai.ml.entities.OneLakeArtifact
165+
:type artifact: Optional[~azure.ai.ml.entities.OneLakeArtifact]
166166
:param one_lake_workspace_name: The name, not ID, of the workspace where the One Lake
167167
resource lives.
168-
:type one_lake_workspace_name: str
168+
:type one_lake_workspace_name: Optional[str]
169169
:param credentials: The credentials for authenticating to the blob store. This type of
170170
connection accepts 3 types of credentials: account key and SAS token credentials,
171171
or NoneCredentialConfiguration for credential-less connections.
@@ -182,8 +182,8 @@ def __init__(
182182
self,
183183
*,
184184
endpoint: str,
185-
artifact: OneLakeConnectionArtifact,
186-
one_lake_workspace_name: str,
185+
artifact: Optional[OneLakeConnectionArtifact] = None,
186+
one_lake_workspace_name: Optional[str] = None,
187187
**kwargs,
188188
):
189189
kwargs.pop("type", None) # make sure we never somehow use wrong type
@@ -192,6 +192,12 @@ def __init__(
192192
# need to worry about data-availability nonsense.
193193
target = kwargs.pop("target", None)
194194
if target is None:
195+
if artifact is None:
196+
raise ValueError("If target is unset, then artifact must be set")
197+
if endpoint is None:
198+
raise ValueError("If target is unset, then endpoint must be set")
199+
if one_lake_workspace_name is None:
200+
raise ValueError("If target is unset, then one_lake_workspace_name must be set")
195201
target = MicrosoftOneLakeConnection._construct_target(endpoint, one_lake_workspace_name, artifact)
196202
super().__init__(
197203
target=target,

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/connections/one_lake_artifacts.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# ---------------------------------------------------------
44

55
# pylint: disable=protected-access
6-
76
from typing import Any
87
from azure.ai.ml._utils._experimental import experimental
98

sdk/ml/azure-ai-ml/tests/connection/e2etests/test_connections.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
OpenAIConnection,
2929
SerpConnection,
3030
ServerlessConnection,
31+
AccountKeyConfiguration,
3132
)
3233
from azure.ai.ml.constants._common import ConnectionTypes
3334
from azure.core.exceptions import ResourceNotFoundError
@@ -398,26 +399,45 @@ def test_workspace_connection_data_connection_listing(
398399
{"account_name": storage_account_name},
399400
]
400401
internal_blob_ds = load_datastore(blob_store_file, params_override=params_override)
401-
created_datastore = self.datastore_create_get_list(client, internal_blob_ds, random_name)
402-
assert isinstance(created_datastore, AzureBlobDatastore)
403-
assert created_datastore.container_name == internal_blob_ds.container_name
404-
assert created_datastore.account_name == internal_blob_ds.account_name
405-
assert created_datastore.credentials.account_key == primary_account_key
406-
407-
# Make sure that normal list call doesn't include data connection
408-
assert internal_blob_ds.name not in [conn.name for conn in client.connections.list()]
409-
410-
# Make sure that the data connection list call includes the data connection
411-
found_datastore_conn = False
412-
for conn in client.connections.list(include_data_connections=True):
413-
if created_datastore.name == conn.name:
414-
assert conn.type == camel_to_snake(ConnectionCategory.AZURE_BLOB)
415-
assert isinstance(conn, AzureBlobStoreConnection)
416-
found_datastore_conn = True
417-
# Ensure that we actually found and validated the data connection.
418-
assert found_datastore_conn
419-
# delete the data store.
420-
client.datastores.delete(random_name)
402+
403+
created_datastore = None
404+
created_connection = None
405+
try:
406+
created_datastore = self.datastore_create_get_list(client, internal_blob_ds, random_name)
407+
assert isinstance(created_datastore, AzureBlobDatastore)
408+
assert created_datastore.container_name == internal_blob_ds.container_name
409+
assert created_datastore.account_name == internal_blob_ds.account_name
410+
assert created_datastore.credentials.account_key == primary_account_key
411+
412+
# Now that a datastore exists, create a connection to it
413+
local_connection = AzureBlobStoreConnection(
414+
name=created_datastore.name,
415+
url=created_datastore.base_path,
416+
account_name=created_datastore.account_name,
417+
container_name=created_datastore.container_name,
418+
credentials=AccountKeyConfiguration(account_key=created_datastore.credentials.account_key),
419+
)
420+
421+
created_connection = client.connections.create_or_update(connection=local_connection)
422+
423+
# Make sure that normal list call doesn't include data connection
424+
assert internal_blob_ds.name not in [conn.name for conn in client.connections.list()]
425+
426+
# Make sure that the data connection list call includes the data connection
427+
found_datastore_conn = False
428+
for conn in client.connections.list(include_data_connections=True):
429+
if created_datastore.name == conn.name:
430+
assert conn.type == camel_to_snake(ConnectionCategory.AZURE_BLOB)
431+
assert isinstance(conn, AzureBlobStoreConnection)
432+
found_datastore_conn = True
433+
# Ensure that we actually found and validated the data connection.
434+
assert found_datastore_conn
435+
finally:
436+
# Delete resources
437+
if created_connection is not None:
438+
client.connections.delete(name=created_datastore.name)
439+
if created_datastore is not None:
440+
client.datastores.delete(random_name)
421441
with pytest.raises(Exception):
422442
client.datastores.get(random_name)
423443

@@ -549,7 +569,7 @@ def test_azure_open_ai_crud(
549569
assert created_connection.type == camel_to_snake(ConnectionCategory.AZURE_OPEN_AI)
550570
assert created_connection.tags is not None
551571
assert created_connection.tags["hello"] == "world"
552-
assert created_connection.api_version == "1.0"
572+
assert created_connection.api_version is None
553573
assert created_connection.open_ai_resource_id == None
554574

555575
with pytest.raises(Exception):

0 commit comments

Comments
 (0)