Skip to content

Commit cd95808

Browse files
Complete marshmallow 4.x Field constructor compatibility fixes
Co-authored-by: kshitij-microsoft <[email protected]>
1 parent 9307b99 commit cd95808

File tree

1 file changed

+34
-6
lines changed
  • sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core

1 file changed

+34
-6
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,30 @@ def _resolve_field_instance(cls_or_instance):
6868
)
6969

7070

71+
def _filter_field_kwargs(**kwargs):
72+
"""
73+
Helper function to filter kwargs to only include parameters that are
74+
valid for marshmallow 4.x Field constructor.
75+
76+
Returns a dict containing only valid Field parameters.
77+
"""
78+
valid_field_params = {}
79+
for param in ['load_default', 'dump_default', 'missing', 'allow_none', 'validate',
80+
'required', 'load_only', 'dump_only', 'error_messages', 'metadata', 'data_key']:
81+
if param in kwargs:
82+
valid_field_params[param] = kwargs[param]
83+
return valid_field_params
84+
85+
7186
class StringTransformedEnum(Field):
7287
def __init__(self, **kwargs):
73-
# pop marshmallow unknown args to avoid warnings
88+
# Extract custom parameters that are not supported by Field in marshmallow 4.x
7489
self.allowed_values = kwargs.pop("allowed_values", None)
7590
self.casing_transform = kwargs.pop("casing_transform", lambda x: x.lower())
7691
self.pass_original = kwargs.pop("pass_original", False)
77-
super().__init__(**kwargs)
92+
93+
# Only pass valid Field parameters to parent constructor
94+
super().__init__(**_filter_field_kwargs(**kwargs))
7895
if isinstance(self.allowed_values, str):
7996
self.allowed_values = [self.allowed_values]
8097
self.allowed_values = [self.casing_transform(x) for x in self.allowed_values]
@@ -284,9 +301,12 @@ class ArmStr(Field):
284301
"""A string represents an ARM ID for some AzureML resource."""
285302

286303
def __init__(self, **kwargs):
304+
# Extract custom parameters that are not supported by Field in marshmallow 4.x
287305
self.azureml_type = kwargs.pop("azureml_type", None)
288306
self.pattern = kwargs.pop("pattern", r"^azureml:.+")
289-
super().__init__(**kwargs)
307+
308+
# Only pass valid Field parameters to parent constructor
309+
super().__init__(**_filter_field_kwargs(**kwargs))
290310

291311
def _jsonschema_type_mapping(self):
292312
schema = {
@@ -338,6 +358,7 @@ class ArmVersionedStr(ArmStr):
338358
"""A string represents an ARM ID for some AzureML resource with version."""
339359

340360
def __init__(self, **kwargs):
361+
# Extract custom parameters that are not supported by Field in marshmallow 4.x
341362
self.allow_default_version = kwargs.pop("allow_default_version", False)
342363
super().__init__(**kwargs)
343364

@@ -459,13 +480,17 @@ class UnionField(fields.Field):
459480
"""A field that can be one of multiple types."""
460481

461482
def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
462-
super().__init__(**kwargs)
483+
# Store custom parameter separately
484+
self.is_strict = is_strict
485+
486+
# Only pass valid Field parameters to parent constructor
487+
super().__init__(**_filter_field_kwargs(**kwargs))
463488
try:
464489
# add the validation and make sure union_fields must be subclasses or instances of
465490
# marshmallow.Field
466491
self._union_fields = [_resolve_field_instance(cls_or_instance) for cls_or_instance in union_fields]
467492
# TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
468-
self.is_strict = is_strict # S\When True, combine fields with oneOf instead of anyOf at schema generation
493+
# When True, combine fields with oneOf instead of anyOf at schema generation
469494
except ValueError as error:
470495
raise ValueError(
471496
'Elements of "union_fields" must be subclasses or instances of marshmallow.Field.'
@@ -935,8 +960,11 @@ class RegistryStr(Field):
935960
"""A string represents a registry ID for some AzureML resource."""
936961

937962
def __init__(self, **kwargs):
963+
# Extract custom parameters that are not supported by Field in marshmallow 4.x
938964
self.azureml_type = kwargs.pop("azureml_type", None)
939-
super().__init__(**kwargs)
965+
966+
# Only pass valid Field parameters to parent constructor
967+
super().__init__(**_filter_field_kwargs(**kwargs))
940968

941969
def _jsonschema_type_mapping(self):
942970
schema = {

0 commit comments

Comments
 (0)