Skip to content

Commit ad054b7

Browse files
Apply code formatting and finalize marshmallow 4.x.x upgrade
Co-authored-by: kshitij-microsoft <[email protected]>
1 parent 37c7eb3 commit ad054b7

File tree

1 file changed

+110
-33
lines changed
  • sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core

1 file changed

+110
-33
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

Lines changed: 110 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,33 @@
2525
# marshmallow 3.x
2626
from marshmallow.utils import resolve_field_instance
2727

28+
2829
# Custom implementation for from_iso_datetime compatibility
2930
def from_iso_datetime(value):
3031
"""Parse an ISO8601 datetime string, handling the 'Z' suffix."""
3132
from datetime import datetime
33+
3234
if isinstance(value, str):
3335
# Replace 'Z' with '+00:00' for compatibility with datetime.fromisoformat
34-
if value.endswith('Z'):
35-
value = value[:-1] + '+00:00'
36+
if value.endswith("Z"):
37+
value = value[:-1] + "+00:00"
3638
return datetime.fromisoformat(value)
3739
return value
3840

39-
from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version
41+
42+
from ..._utils._arm_id_utils import (
43+
AMLVersionedArmId,
44+
is_ARM_id_for_resource,
45+
parse_name_label,
46+
parse_name_version,
47+
)
4048
from ..._utils._experimental import _is_warning_cached
41-
from ..._utils.utils import is_data_binding_expression, is_valid_node_name, load_file, load_yaml
49+
from ..._utils.utils import (
50+
is_data_binding_expression,
51+
is_valid_node_name,
52+
load_file,
53+
load_yaml,
54+
)
4255
from ...constants._common import (
4356
ARM_ID_PREFIX,
4457
AZUREML_RESOURCE_PROVIDER,
@@ -86,14 +99,24 @@ def _jsonschema_type_mapping(self):
8699
def _serialize(self, value, attr, obj, **kwargs):
87100
if not value:
88101
return None
89-
if isinstance(value, str) and self.casing_transform(value) in self.allowed_values:
102+
if (
103+
isinstance(value, str)
104+
and self.casing_transform(value) in self.allowed_values
105+
):
90106
return value if self.pass_original else self.casing_transform(value)
91-
raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")
107+
raise ValidationError(
108+
f"Value {value!r} passed is not in set {self.allowed_values}"
109+
)
92110

93111
def _deserialize(self, value, attr, data, **kwargs):
94-
if isinstance(value, str) and self.casing_transform(value) in self.allowed_values:
112+
if (
113+
isinstance(value, str)
114+
and self.casing_transform(value) in self.allowed_values
115+
):
95116
return value if self.pass_original else self.casing_transform(value)
96-
raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")
117+
raise ValidationError(
118+
f"Value {value!r} passed is not in set {self.allowed_values}"
119+
)
97120

98121

99122
class DumpableEnumField(StringTransformedEnum):
@@ -148,11 +171,15 @@ def _resolve_path(self, value: Union[str, os.PathLike]) -> Path:
148171
# for non-path string like "azureml:/xxx", OSError can be raised in either
149172
# resolve() or is_dir() or is_file()
150173
result = result.resolve()
151-
if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()):
174+
if (self._allow_dir and result.is_dir()) or (
175+
self._allow_file and result.is_file()
176+
):
152177
return result
153178
except OSError as e:
154179
raise self.make_error("invalid_path") from e
155-
raise self.make_error("path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type)
180+
raise self.make_error(
181+
"path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type
182+
)
156183

157184
@property
158185
def allowed_path_type(self) -> str:
@@ -175,7 +202,9 @@ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
175202
if value is None:
176203
return None
177204
# always dump path as absolute path in string as base_path will be dropped after serialization
178-
return super(LocalPathField, self)._serialize(self._resolve_path(value).as_posix(), attr, obj, **kwargs)
205+
return super(LocalPathField, self)._serialize(
206+
self._resolve_path(value).as_posix(), attr, obj, **kwargs
207+
)
179208

180209

181210
class SerializeValidatedUrl(fields.Url):
@@ -268,7 +297,9 @@ def _validate(self, value):
268297
try:
269298
from_iso_datetime(value)
270299
except Exception as e:
271-
raise ValidationError(f"Not a valid ISO8601-formatted datetime string: {value}") from e
300+
raise ValidationError(
301+
f"Not a valid ISO8601-formatted datetime string: {value}"
302+
) from e
272303

273304

274305
class ArmStr(Field):
@@ -293,7 +324,9 @@ def _jsonschema_type_mapping(self):
293324

294325
def _serialize(self, value, attr, obj, **kwargs):
295326
if isinstance(value, str):
296-
serialized_value = value if value.startswith(ARM_ID_PREFIX) else f"{ARM_ID_PREFIX}{value}"
327+
serialized_value = (
328+
value if value.startswith(ARM_ID_PREFIX) else f"{ARM_ID_PREFIX}{value}"
329+
)
297330
return serialized_value
298331
if value is None and not self.required:
299332
return None
@@ -312,7 +345,9 @@ def _deserialize(self, value, attr, data, **kwargs):
312345
if self.azureml_type is not None:
313346
azureml_type_suffix = self.azureml_type
314347
else:
315-
azureml_type_suffix = "<asset_type>" + "/<resource_name>/<version-if applicable>)"
348+
azureml_type_suffix = (
349+
"<asset_type>" + "/<resource_name>/<version-if applicable>)"
350+
)
316351
raise ValidationError(
317352
f"In order to specify an existing {self.azureml_type if self.azureml_type is not None else 'asset'}, "
318353
"please provide either of the following prefixed with 'azureml:':\n"
@@ -358,7 +393,9 @@ def _deserialize(self, value, attr, data, **kwargs):
358393
if not (label or version):
359394
if self.allow_default_version:
360395
return name
361-
raise ValidationError(f"Either version or label is not provided for {attr} or the id is not valid.")
396+
raise ValidationError(
397+
f"Either version or label is not provided for {attr} or the id is not valid."
398+
)
362399

363400
if version:
364401
return f"{name}:{version}"
@@ -453,7 +490,10 @@ def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
453490
try:
454491
# add the validation and make sure union_fields must be subclasses or instances of
455492
# marshmallow fields
456-
self._union_fields = [resolve_field_instance(cls_or_instance) for cls_or_instance in union_fields]
493+
self._union_fields = [
494+
resolve_field_instance(cls_or_instance)
495+
for cls_or_instance in union_fields
496+
]
457497
# TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
458498
self.is_strict = is_strict # S\When True, combine fields with oneOf instead of anyOf at schema generation
459499
except FieldInstanceResolutionError as error:
@@ -528,9 +568,13 @@ def _deserialize(self, value, attr, data, **kwargs):
528568
and isinstance(schema.schema, PathAwareSchema)
529569
):
530570
# use old base path to recover original base path
531-
schema.schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.old_base_path
571+
schema.schema.context[BASE_PATH_CONTEXT_KEY] = (
572+
schema.schema.old_base_path
573+
)
532574
# recover base path of parent schema
533-
schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.context[BASE_PATH_CONTEXT_KEY]
575+
schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.context[
576+
BASE_PATH_CONTEXT_KEY
577+
]
534578
raise ValidationError(errors, field_name=attr)
535579

536580

@@ -566,7 +610,8 @@ def __init__(
566610
for type_name, type_sensitive_fields in type_sensitive_fields_dict.items():
567611
union_fields.extend(type_sensitive_fields)
568612
self._type_sensitive_fields_dict[type_name] = [
569-
resolve_field_instance(cls_or_instance) for cls_or_instance in type_sensitive_fields
613+
resolve_field_instance(cls_or_instance)
614+
for cls_or_instance in type_sensitive_fields
570615
]
571616

572617
super(TypeSensitiveUnionField, self).__init__(union_fields, **kwargs)
@@ -578,7 +623,9 @@ def _bind_to_schema(self, field_name, schema):
578623
type_name,
579624
type_sensitive_fields,
580625
) in self._type_sensitive_fields_dict.items():
581-
self._type_sensitive_fields_dict[type_name] = self._create_bind_fields(type_sensitive_fields, field_name)
626+
self._type_sensitive_fields_dict[type_name] = self._create_bind_fields(
627+
type_sensitive_fields, field_name
628+
)
582629

583630
@property
584631
def type_field_name(self) -> str:
@@ -613,7 +660,9 @@ def _simplified_error_base_on_type(self, e, value, attr) -> Exception:
613660
if value_type not in self.allowed_types:
614661
# if value has type field but its value doesn't match any allowed value, raise ValidationError directly
615662
return ValidationError(
616-
message={self.type_field_name: f"Value {value_type!r} passed is not in set {self.allowed_types}"},
663+
message={
664+
self.type_field_name: f"Value {value_type!r} passed is not in set {self.allowed_types}"
665+
},
617666
field_name=attr,
618667
)
619668
filtered_messages = []
@@ -644,7 +693,9 @@ def _serialize(self, value, attr, obj, **kwargs):
644693
self._union_fields = target_fields
645694

646695
try:
647-
return super(TypeSensitiveUnionField, self)._serialize(value, attr, obj, **kwargs)
696+
return super(TypeSensitiveUnionField, self)._serialize(
697+
value, attr, obj, **kwargs
698+
)
648699
except ValidationError as e:
649700
raise self._simplified_error_base_on_type(e, value, attr)
650701
finally:
@@ -672,7 +723,9 @@ def _try_load_from_yaml(self, value):
672723

673724
def _deserialize(self, value, attr, data, **kwargs):
674725
try:
675-
return super(TypeSensitiveUnionField, self)._deserialize(value, attr, data, **kwargs)
726+
return super(TypeSensitiveUnionField, self)._deserialize(
727+
value, attr, data, **kwargs
728+
)
676729
except ValidationError as e:
677730
if isinstance(value, str) and self._allow_load_from_yaml:
678731
value = self._try_load_from_yaml(value)
@@ -711,7 +764,9 @@ def CodeField(**kwargs) -> Field:
711764
# put arm versioned string at last order as it can deserialize any string into "azureml:<origin>"
712765
ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
713766
],
714-
metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
767+
metadata={
768+
"description": "A local path or http:, https:, azureml: url pointing to a remote location."
769+
},
715770
**kwargs,
716771
)
717772

@@ -732,7 +787,9 @@ def EnvironmentField(*, extra_fields: List[Field] = None, **kwargs):
732787
[
733788
NestedField(AnonymousEnvironmentSchema),
734789
RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
735-
ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
790+
ArmVersionedStr(
791+
azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True
792+
),
736793
]
737794
+ extra_fields,
738795
**kwargs,
@@ -825,9 +882,19 @@ class NumberVersionField(VersionField):
825882
"invalid": "Number version must be integers concatenated by '.', like 1.0.1.",
826883
}
827884

828-
def __init__(self, *args, upper_bound: Optional[str] = None, lower_bound: Optional[str] = None, **kwargs) -> None:
829-
self._upper = None if upper_bound is None else self._version_to_tuple(upper_bound)
830-
self._lower = None if lower_bound is None else self._version_to_tuple(lower_bound)
885+
def __init__(
886+
self,
887+
*args,
888+
upper_bound: Optional[str] = None,
889+
lower_bound: Optional[str] = None,
890+
**kwargs,
891+
) -> None:
892+
self._upper = (
893+
None if upper_bound is None else self._version_to_tuple(upper_bound)
894+
)
895+
self._lower = (
896+
None if lower_bound is None else self._version_to_tuple(lower_bound)
897+
)
831898
super().__init__(*args, **kwargs)
832899

833900
def _version_to_tuple(self, value: str):
@@ -848,7 +915,9 @@ def _validate(self, value):
848915
class DumpableIntegerField(fields.Integer):
849916
"""A int field that cannot serialize other type of values to int if self.strict."""
850917

851-
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
918+
def _serialize(
919+
self, value, attr, obj, **kwargs
920+
) -> typing.Optional[typing.Union[str, T]]:
852921
if self.strict and not isinstance(value, int):
853922
# this implementation can serialize bool to bool
854923
raise self.make_error("invalid", input=value)
@@ -874,14 +943,18 @@ def _validated(self, value):
874943
raise self.make_error("invalid", input=value)
875944
return super()._validated(value)
876945

877-
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
946+
def _serialize(
947+
self, value, attr, obj, **kwargs
948+
) -> typing.Optional[typing.Union[str, T]]:
878949
return super()._serialize(self._validated(value), attr, obj, **kwargs)
879950

880951

881952
class DumpableStringField(fields.String):
882953
"""A string field that cannot serialize other type of values to string if self.strict."""
883954

884-
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
955+
def _serialize(
956+
self, value, attr, obj, **kwargs
957+
) -> typing.Optional[typing.Union[str, T]]:
885958
if not isinstance(value, str):
886959
raise ValidationError("Given value is not a string")
887960
return super()._serialize(value, attr, obj, **kwargs)
@@ -914,7 +987,9 @@ def _serialize(self, value, attr, obj, **kwargs):
914987

915988
def _deserialize(self, value, attr, data, **kwargs):
916989
if value is not None:
917-
message = "Field '{0}': {1} {2}".format(attr, EXPERIMENTAL_FIELD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
990+
message = "Field '{0}': {1} {2}".format(
991+
attr, EXPERIMENTAL_FIELD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE
992+
)
918993
if not _is_warning_cached(message):
919994
module_logger.warning(message)
920995

@@ -1043,4 +1118,6 @@ def _serialize(self, value, attr, obj, **kwargs):
10431118
def _deserialize(self, value, attr, data, **kwargs):
10441119
if isinstance(value, str) and value.startswith("git+"):
10451120
return value
1046-
raise ValidationError("In order to specify a git path, please provide the correct path prefixed with 'git+\n")
1121+
raise ValidationError(
1122+
"In order to specify a git path, please provide the correct path prefixed with 'git+\n"
1123+
)

0 commit comments

Comments
 (0)