Skip to content

Commit 9fc59a8

Browse files
black fixes
1 parent 3433418 commit 9fc59a8

File tree

1 file changed

+32
-91
lines changed
  • sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core

1 file changed

+32
-91
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

Lines changed: 32 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,17 @@ def from_iso_datetime(value):
7474
if isinstance(value, str):
7575
# Validate that this is a proper datetime string, not just a date
7676
# The original marshmallow from_iso_datetime expects datetime format
77-
if 'T' not in value and 'Z' not in value and ' ' not in value:
77+
if "T" not in value and "Z" not in value and " " not in value:
7878
# This is likely just a date string, not a datetime
7979
raise ValueError(f"Expected datetime string but got date string: {value}")
80-
80+
8181
# Replace 'Z' with '+00:00' for compatibility with datetime.fromisoformat
8282
if value.endswith("Z"):
8383
value = value[:-1] + "+00:00"
8484
return datetime.fromisoformat(value)
8585
return value
8686

87+
8788
module_logger = logging.getLogger(__name__)
8889
T = typing.TypeVar("T")
8990

@@ -110,24 +111,14 @@ def _jsonschema_type_mapping(self):
110111
def _serialize(self, value, attr, obj, **kwargs):
111112
if not value:
112113
return None
113-
if (
114-
isinstance(value, str)
115-
and self.casing_transform(value) in self.allowed_values
116-
):
114+
if isinstance(value, str) and self.casing_transform(value) in self.allowed_values:
117115
return value if self.pass_original else self.casing_transform(value)
118-
raise ValidationError(
119-
f"Value {value!r} passed is not in set {self.allowed_values}"
120-
)
116+
raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")
121117

122118
def _deserialize(self, value, attr, data, **kwargs):
123-
if (
124-
isinstance(value, str)
125-
and self.casing_transform(value) in self.allowed_values
126-
):
119+
if isinstance(value, str) and self.casing_transform(value) in self.allowed_values:
127120
return value if self.pass_original else self.casing_transform(value)
128-
raise ValidationError(
129-
f"Value {value!r} passed is not in set {self.allowed_values}"
130-
)
121+
raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}")
131122

132123

133124
class DumpableEnumField(StringTransformedEnum):
@@ -182,15 +173,11 @@ def _resolve_path(self, value: Union[str, os.PathLike]) -> Path:
182173
# for non-path string like "azureml:/xxx", OSError can be raised in either
183174
# resolve() or is_dir() or is_file()
184175
result = result.resolve()
185-
if (self._allow_dir and result.is_dir()) or (
186-
self._allow_file and result.is_file()
187-
):
176+
if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()):
188177
return result
189178
except OSError as e:
190179
raise self.make_error("invalid_path") from e
191-
raise self.make_error(
192-
"path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type
193-
)
180+
raise self.make_error("path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type)
194181

195182
@property
196183
def allowed_path_type(self) -> str:
@@ -213,9 +200,7 @@ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
213200
if value is None:
214201
return None
215202
# always dump path as absolute path in string as base_path will be dropped after serialization
216-
return super(LocalPathField, self)._serialize(
217-
self._resolve_path(value).as_posix(), attr, obj, **kwargs
218-
)
203+
return super(LocalPathField, self)._serialize(self._resolve_path(value).as_posix(), attr, obj, **kwargs)
219204

220205

221206
class SerializeValidatedUrl(fields.Url):
@@ -308,9 +293,7 @@ def _validate(self, value):
308293
try:
309294
from_iso_datetime(value)
310295
except Exception as e:
311-
raise ValidationError(
312-
f"Not a valid ISO8601-formatted datetime string: {value}"
313-
) from e
296+
raise ValidationError(f"Not a valid ISO8601-formatted datetime string: {value}") from e
314297

315298

316299
class ArmStr(Field):
@@ -335,9 +318,7 @@ def _jsonschema_type_mapping(self):
335318

336319
def _serialize(self, value, attr, obj, **kwargs):
337320
if isinstance(value, str):
338-
serialized_value = (
339-
value if value.startswith(ARM_ID_PREFIX) else f"{ARM_ID_PREFIX}{value}"
340-
)
321+
serialized_value = value if value.startswith(ARM_ID_PREFIX) else f"{ARM_ID_PREFIX}{value}"
341322
return serialized_value
342323
if value is None and not self.required:
343324
return None
@@ -356,9 +337,7 @@ def _deserialize(self, value, attr, data, **kwargs):
356337
if self.azureml_type is not None:
357338
azureml_type_suffix = self.azureml_type
358339
else:
359-
azureml_type_suffix = (
360-
"<asset_type>" + "/<resource_name>/<version-if applicable>)"
361-
)
340+
azureml_type_suffix = "<asset_type>" + "/<resource_name>/<version-if applicable>)"
362341
raise ValidationError(
363342
f"In order to specify an existing {self.azureml_type if self.azureml_type is not None else 'asset'}, "
364343
"please provide either of the following prefixed with 'azureml:':\n"
@@ -404,9 +383,7 @@ def _deserialize(self, value, attr, data, **kwargs):
404383
if not (label or version):
405384
if self.allow_default_version:
406385
return name
407-
raise ValidationError(
408-
f"Either version or label is not provided for {attr} or the id is not valid."
409-
)
386+
raise ValidationError(f"Either version or label is not provided for {attr} or the id is not valid.")
410387

411388
if version:
412389
return f"{name}:{version}"
@@ -501,10 +478,7 @@ def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
501478
try:
502479
# add the validation and make sure union_fields must be subclasses or instances of
503480
# marshmallow fields
504-
self._union_fields = [
505-
resolve_field_instance(cls_or_instance)
506-
for cls_or_instance in union_fields
507-
]
481+
self._union_fields = [resolve_field_instance(cls_or_instance) for cls_or_instance in union_fields]
508482
# TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
509483
self.is_strict = is_strict # S\When True, combine fields with oneOf instead of anyOf at schema generation
510484
except FieldInstanceResolutionError as error:
@@ -579,13 +553,9 @@ def _deserialize(self, value, attr, data, **kwargs):
579553
and isinstance(schema.schema, PathAwareSchema)
580554
):
581555
# use old base path to recover original base path
582-
schema.schema.context[BASE_PATH_CONTEXT_KEY] = (
583-
schema.schema.old_base_path
584-
)
556+
schema.schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.old_base_path
585557
# recover base path of parent schema
586-
schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.context[
587-
BASE_PATH_CONTEXT_KEY
588-
]
558+
schema.context[BASE_PATH_CONTEXT_KEY] = schema.schema.context[BASE_PATH_CONTEXT_KEY]
589559
raise ValidationError(errors, field_name=attr)
590560

591561

@@ -621,8 +591,7 @@ def __init__(
621591
for type_name, type_sensitive_fields in type_sensitive_fields_dict.items():
622592
union_fields.extend(type_sensitive_fields)
623593
self._type_sensitive_fields_dict[type_name] = [
624-
resolve_field_instance(cls_or_instance)
625-
for cls_or_instance in type_sensitive_fields
594+
resolve_field_instance(cls_or_instance) for cls_or_instance in type_sensitive_fields
626595
]
627596

628597
super(TypeSensitiveUnionField, self).__init__(union_fields, **kwargs)
@@ -634,9 +603,7 @@ def _bind_to_schema(self, field_name, schema):
634603
type_name,
635604
type_sensitive_fields,
636605
) in self._type_sensitive_fields_dict.items():
637-
self._type_sensitive_fields_dict[type_name] = self._create_bind_fields(
638-
type_sensitive_fields, field_name
639-
)
606+
self._type_sensitive_fields_dict[type_name] = self._create_bind_fields(type_sensitive_fields, field_name)
640607

641608
@property
642609
def type_field_name(self) -> str:
@@ -671,9 +638,7 @@ def _simplified_error_base_on_type(self, e, value, attr) -> Exception:
671638
if value_type not in self.allowed_types:
672639
# if value has type field but its value doesn't match any allowed value, raise ValidationError directly
673640
return ValidationError(
674-
message={
675-
self.type_field_name: f"Value {value_type!r} passed is not in set {self.allowed_types}"
676-
},
641+
message={self.type_field_name: f"Value {value_type!r} passed is not in set {self.allowed_types}"},
677642
field_name=attr,
678643
)
679644
filtered_messages = []
@@ -704,9 +669,7 @@ def _serialize(self, value, attr, obj, **kwargs):
704669
self._union_fields = target_fields
705670

706671
try:
707-
return super(TypeSensitiveUnionField, self)._serialize(
708-
value, attr, obj, **kwargs
709-
)
672+
return super(TypeSensitiveUnionField, self)._serialize(value, attr, obj, **kwargs)
710673
except ValidationError as e:
711674
raise self._simplified_error_base_on_type(e, value, attr)
712675
finally:
@@ -734,9 +697,7 @@ def _try_load_from_yaml(self, value):
734697

735698
def _deserialize(self, value, attr, data, **kwargs):
736699
try:
737-
return super(TypeSensitiveUnionField, self)._deserialize(
738-
value, attr, data, **kwargs
739-
)
700+
return super(TypeSensitiveUnionField, self)._deserialize(value, attr, data, **kwargs)
740701
except ValidationError as e:
741702
if isinstance(value, str) and self._allow_load_from_yaml:
742703
value = self._try_load_from_yaml(value)
@@ -775,9 +736,7 @@ def CodeField(**kwargs) -> Field:
775736
# put arm versioned string at last order as it can deserialize any string into "azureml:<origin>"
776737
ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
777738
],
778-
metadata={
779-
"description": "A local path or http:, https:, azureml: url pointing to a remote location."
780-
},
739+
metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
781740
**kwargs,
782741
)
783742

@@ -798,9 +757,7 @@ def EnvironmentField(*, extra_fields: List[Field] = None, **kwargs):
798757
[
799758
NestedField(AnonymousEnvironmentSchema),
800759
RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
801-
ArmVersionedStr(
802-
azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True
803-
),
760+
ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
804761
]
805762
+ extra_fields,
806763
**kwargs,
@@ -900,12 +857,8 @@ def __init__(
900857
lower_bound: Optional[str] = None,
901858
**kwargs,
902859
) -> None:
903-
self._upper = (
904-
None if upper_bound is None else self._version_to_tuple(upper_bound)
905-
)
906-
self._lower = (
907-
None if lower_bound is None else self._version_to_tuple(lower_bound)
908-
)
860+
self._upper = None if upper_bound is None else self._version_to_tuple(upper_bound)
861+
self._lower = None if lower_bound is None else self._version_to_tuple(lower_bound)
909862
super().__init__(*args, **kwargs)
910863

911864
def _version_to_tuple(self, value: str):
@@ -926,9 +879,7 @@ def _validate(self, value):
926879
class DumpableIntegerField(fields.Integer):
927880
"""A int field that cannot serialize other type of values to int if self.strict."""
928881

929-
def _serialize(
930-
self, value, attr, obj, **kwargs
931-
) -> typing.Optional[typing.Union[str, T]]:
882+
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
932883
if self.strict and not isinstance(value, int):
933884
# this implementation can serialize bool to bool
934885
raise self.make_error("invalid", input=value)
@@ -954,18 +905,14 @@ def _validated(self, value):
954905
raise self.make_error("invalid", input=value)
955906
return super()._validated(value)
956907

957-
def _serialize(
958-
self, value, attr, obj, **kwargs
959-
) -> typing.Optional[typing.Union[str, T]]:
908+
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
960909
return super()._serialize(self._validated(value), attr, obj, **kwargs)
961910

962911

963912
class DumpableStringField(fields.String):
964913
"""A string field that cannot serialize other type of values to string if self.strict."""
965914

966-
def _serialize(
967-
self, value, attr, obj, **kwargs
968-
) -> typing.Optional[typing.Union[str, T]]:
915+
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
969916
if not isinstance(value, str):
970917
raise ValidationError("Given value is not a string")
971918
return super()._serialize(value, attr, obj, **kwargs)
@@ -978,9 +925,7 @@ def __init__(self, experimental_field: fields.Field, **kwargs):
978925
self._experimental_field = resolve_field_instance(experimental_field)
979926
self.required = experimental_field.required
980927
except FieldInstanceResolutionError as error:
981-
raise ValueError(
982-
'"experimental_field" must be subclasses or instances of marshmallow fields.'
983-
) from error
928+
raise ValueError('"experimental_field" must be subclasses or instances of marshmallow fields.') from error
984929

985930
@property
986931
def experimental_field(self):
@@ -998,9 +943,7 @@ def _serialize(self, value, attr, obj, **kwargs):
998943

999944
def _deserialize(self, value, attr, data, **kwargs):
1000945
if value is not None:
1001-
message = "Field '{0}': {1} {2}".format(
1002-
attr, EXPERIMENTAL_FIELD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE
1003-
)
946+
message = "Field '{0}': {1} {2}".format(attr, EXPERIMENTAL_FIELD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
1004947
if not _is_warning_cached(message):
1005948
module_logger.warning(message)
1006949

@@ -1129,6 +1072,4 @@ def _serialize(self, value, attr, obj, **kwargs):
11291072
def _deserialize(self, value, attr, data, **kwargs):
11301073
if isinstance(value, str) and value.startswith("git+"):
11311074
return value
1132-
raise ValidationError(
1133-
"In order to specify a git path, please provide the correct path prefixed with 'git+\n"
1134-
)
1075+
raise ValidationError("In order to specify a git path, please provide the correct path prefixed with 'git+\n")

0 commit comments

Comments
 (0)