Skip to content

Commit 20c8e82

Browse files
Complete marshmallow 4.x migration with validation tests and changelog update
Co-authored-by: kshitij-microsoft <[email protected]>
1 parent f277b77 commit 20c8e82

File tree

2 files changed

+352
-0
lines changed

2 files changed

+352
-0
lines changed

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
### Bugs Fixed
77

8+
### Other Changes
9+
- Upgraded `marshmallow` dependency from version 3.x to 4.x (`>=4.0.0,<5.0.0`) for improved performance and compatibility with latest serialization standards.
10+
811
## 1.27.1 (2025-05-13)
912

1013
### Bugs Fixed
Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Marshmallow 4.x Migration Validation Script for Azure AI ML
4+
5+
This script validates that the marshmallow patterns used in azure-ai-ml
6+
are compatible with marshmallow 4.x. Run this script after upgrading
7+
marshmallow to verify the migration was successful.
8+
9+
Usage:
10+
python test_marshmallow_migration.py
11+
12+
The script tests:
13+
- Basic marshmallow imports and functionality
14+
- Custom schema metaclass patterns (PatchedSchemaMeta)
15+
- PathAware schema decorators (pre_load, post_dump)
16+
- Validation error handling patterns
17+
- Field usage patterns (Nested, Dict, List, etc.)
18+
- marshmallow-jsonschema compatibility (if available)
19+
"""
20+
21+
import sys
22+
import os
23+
import traceback
24+
from pathlib import Path
25+
26+
def get_marshmallow_version():
27+
"""Get current marshmallow version using the new preferred method"""
28+
try:
29+
# Marshmallow 4.x preferred method
30+
import importlib.metadata
31+
return importlib.metadata.version("marshmallow")
32+
except ImportError:
33+
try:
34+
# Fallback for older Python versions
35+
import pkg_resources
36+
return pkg_resources.get_distribution("marshmallow").version
37+
except:
38+
pass
39+
except:
40+
pass
41+
42+
try:
43+
# Deprecated method (will show warning in 3.x, removed in 4.x)
44+
import marshmallow
45+
if hasattr(marshmallow, '__version__'):
46+
return marshmallow.__version__
47+
except:
48+
pass
49+
50+
return "unknown"
51+
52+
def test_basic_imports():
53+
"""Test that all basic marshmallow imports work"""
54+
try:
55+
from marshmallow import Schema, fields, RAISE, ValidationError, INCLUDE, EXCLUDE
56+
from marshmallow.decorators import post_dump, post_load, pre_load, validates
57+
from marshmallow.schema import SchemaMeta
58+
from marshmallow.exceptions import ValidationError as SchemaValidationError
59+
from marshmallow.fields import Field, Nested
60+
from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance
61+
62+
print("✓ All basic imports successful")
63+
return True
64+
except Exception as e:
65+
print(f"✗ Basic imports failed: {e}")
66+
traceback.print_exc()
67+
return False
68+
69+
def test_patched_schema_metaclass():
70+
"""Test the PatchedSchemaMeta pattern used in azure-ai-ml"""
71+
try:
72+
from marshmallow import Schema, RAISE, fields
73+
from marshmallow.decorators import post_dump
74+
from marshmallow.schema import SchemaMeta
75+
from collections import OrderedDict
76+
77+
class PatchedMeta:
78+
ordered = True
79+
unknown = RAISE
80+
81+
class PatchedBaseSchema(Schema):
82+
class Meta:
83+
unknown = RAISE
84+
ordered = True
85+
86+
@post_dump
87+
def remove_none(self, data, **kwargs):
88+
"""Remove None values from dumped data"""
89+
return OrderedDict((key, value) for key, value in data.items() if value is not None)
90+
91+
class PatchedSchemaMeta(SchemaMeta):
92+
"""Custom metaclass that injects Meta attributes"""
93+
def __new__(mcs, name, bases, dct):
94+
meta = dct.get("Meta")
95+
if meta is None:
96+
dct["Meta"] = PatchedMeta
97+
else:
98+
if not hasattr(meta, "unknown"):
99+
dct["Meta"].unknown = RAISE
100+
if not hasattr(meta, "ordered"):
101+
dct["Meta"].ordered = True
102+
103+
if PatchedBaseSchema not in bases:
104+
bases = bases + (PatchedBaseSchema,)
105+
klass = super().__new__(mcs, name, bases, dct)
106+
return klass
107+
108+
# Test schema creation and usage
109+
class TestSchema(PatchedBaseSchema, metaclass=PatchedSchemaMeta):
110+
name = fields.Str(required=True)
111+
count = fields.Int()
112+
tags = fields.Dict()
113+
114+
schema = TestSchema()
115+
116+
# Test dump with None removal
117+
test_data = {"name": "test", "count": 42, "extra": None, "tags": {"env": "prod"}}
118+
result = schema.dump(test_data)
119+
120+
# Verify None was removed and order is preserved
121+
if isinstance(result, OrderedDict) and "extra" not in result:
122+
print("✓ PatchedSchemaMeta works correctly")
123+
return True
124+
else:
125+
print("✗ PatchedSchemaMeta behavior changed")
126+
return False
127+
128+
except Exception as e:
129+
print(f"✗ PatchedSchemaMeta failed: {e}")
130+
traceback.print_exc()
131+
return False
132+
133+
def test_pathaware_schema_decorators():
134+
"""Test pre_load and post_dump decorators used in PathAwareSchema"""
135+
try:
136+
from marshmallow import Schema, fields, RAISE
137+
from marshmallow.decorators import pre_load, post_dump
138+
from collections import OrderedDict
139+
140+
class TestPathAwareSchema(Schema):
141+
class Meta:
142+
unknown = RAISE
143+
ordered = True
144+
145+
schema_ignored = fields.Str(data_key="$schema", dump_only=True)
146+
name = fields.Str(required=True)
147+
description = fields.Str()
148+
149+
@post_dump
150+
def remove_none(self, data, **kwargs):
151+
return OrderedDict((key, value) for key, value in data.items() if value is not None)
152+
153+
@pre_load
154+
def trim_dump_only(self, data, **kwargs):
155+
"""Remove dump_only fields from load data"""
156+
if isinstance(data, str) or data is None:
157+
return data
158+
for key, value in self.fields.items():
159+
if value.dump_only:
160+
schema_key = value.data_key or key
161+
if isinstance(data, dict) and data.get(schema_key, None) is not None:
162+
data.pop(schema_key)
163+
return data
164+
165+
schema = TestPathAwareSchema()
166+
167+
# Test that dump_only field is included in dump but ignored in load
168+
test_data = {
169+
"name": "test",
170+
"description": "description",
171+
"$schema": "should_be_ignored_on_load"
172+
}
173+
174+
loaded = schema.load(test_data)
175+
dumped = schema.dump({"name": "test", "description": "description"})
176+
177+
print("✓ PathAware schema decorators work")
178+
return True
179+
180+
except Exception as e:
181+
print(f"✗ PathAware schema decorators failed: {e}")
182+
traceback.print_exc()
183+
return False
184+
185+
def test_validation_error_handling():
186+
"""Test validation error patterns used throughout the codebase"""
187+
try:
188+
from marshmallow import Schema, fields, ValidationError
189+
from marshmallow.exceptions import ValidationError as SchemaValidationError
190+
191+
class TestSchema(Schema):
192+
required_field = fields.Str(required=True)
193+
int_field = fields.Int()
194+
195+
schema = TestSchema()
196+
197+
# Test ValidationError import pattern (used in operations files)
198+
validation_error_caught = False
199+
try:
200+
schema.load({}) # Missing required field
201+
except ValidationError:
202+
validation_error_caught = True
203+
204+
# Test SchemaValidationError import pattern (used in operations files)
205+
schema_validation_error_caught = False
206+
try:
207+
schema.load({"required_field": "ok", "int_field": "not_an_int"})
208+
except SchemaValidationError:
209+
schema_validation_error_caught = True
210+
211+
if validation_error_caught and schema_validation_error_caught:
212+
print("✓ Validation error handling works")
213+
return True
214+
else:
215+
print("✗ Validation error handling failed")
216+
return False
217+
218+
except Exception as e:
219+
print(f"✗ Validation error handling failed: {e}")
220+
traceback.print_exc()
221+
return False
222+
223+
def test_field_patterns():
224+
"""Test field usage patterns from the codebase"""
225+
try:
226+
from marshmallow import Schema, fields, validates, ValidationError
227+
228+
class TestSchema(Schema):
229+
# Common field patterns from the codebase
230+
str_field = fields.Str()
231+
int_field = fields.Int()
232+
bool_field = fields.Bool()
233+
list_field = fields.List(fields.Str())
234+
dict_field = fields.Dict(keys=fields.Str(), values=fields.Str())
235+
236+
# Fields with options
237+
required_field = fields.Str(required=True)
238+
nullable_field = fields.Str(allow_none=True)
239+
aliased_field = fields.Str(data_key="alias")
240+
dump_only_field = fields.Str(dump_only=True)
241+
load_only_field = fields.Str(load_only=True)
242+
243+
# Nested field with lambda (marshmallow 4.x compatible)
244+
nested_field = fields.Nested(lambda: TestSchema(), allow_none=True)
245+
246+
@validates('str_field')
247+
def validate_str_field(self, value):
248+
if value == "invalid":
249+
raise ValidationError("Invalid value")
250+
251+
schema = TestSchema()
252+
253+
# Test various operations
254+
test_data = {
255+
"str_field": "test",
256+
"int_field": 42,
257+
"bool_field": True,
258+
"list_field": ["a", "b"],
259+
"dict_field": {"key": "value"},
260+
"required_field": "required",
261+
"nullable_field": None,
262+
"alias": "aliased",
263+
"dump_only_field": "dump",
264+
"load_only_field": "load"
265+
}
266+
267+
dumped = schema.dump(test_data)
268+
loaded = schema.load({
269+
"str_field": "test",
270+
"required_field": "required",
271+
"alias": "aliased",
272+
"load_only_field": "load"
273+
})
274+
275+
print("✓ Field patterns work")
276+
return True
277+
278+
except Exception as e:
279+
print(f"✗ Field patterns failed: {e}")
280+
traceback.print_exc()
281+
return False
282+
283+
def test_marshmallow_jsonschema_compatibility():
284+
"""Test marshmallow-jsonschema compatibility if available"""
285+
try:
286+
from marshmallow_jsonschema import JSONSchema
287+
from marshmallow import Schema, fields
288+
289+
class TestSchema(Schema):
290+
name = fields.Str()
291+
count = fields.Int()
292+
293+
json_schema = JSONSchema()
294+
schema_dict = json_schema.dump(TestSchema())
295+
296+
print("✓ marshmallow-jsonschema compatibility works")
297+
return True
298+
299+
except ImportError:
300+
print("ℹ marshmallow-jsonschema not available, skipping test")
301+
return True
302+
except Exception as e:
303+
print(f"✗ marshmallow-jsonschema compatibility failed: {e}")
304+
traceback.print_exc()
305+
return False
306+
307+
def run_migration_tests():
308+
"""Run all migration tests"""
309+
version = get_marshmallow_version()
310+
print(f"Marshmallow 4.x Migration Validation")
311+
print(f"Marshmallow version: {version}")
312+
print("=" * 60)
313+
314+
tests = [
315+
("Basic Imports", test_basic_imports),
316+
("PatchedSchemaMeta", test_patched_schema_metaclass),
317+
("PathAware Decorators", test_pathaware_schema_decorators),
318+
("Validation Errors", test_validation_error_handling),
319+
("Field Patterns", test_field_patterns),
320+
("JSONSchema Compatibility", test_marshmallow_jsonschema_compatibility),
321+
]
322+
323+
results = []
324+
for test_name, test_func in tests:
325+
print(f"Running {test_name}...")
326+
try:
327+
result = test_func()
328+
results.append(result)
329+
except Exception as e:
330+
print(f"✗ {test_name} crashed: {e}")
331+
results.append(False)
332+
print()
333+
334+
print("=" * 60)
335+
passed = sum(results)
336+
total = len(results)
337+
338+
print(f"Migration test results: {passed}/{total} passed")
339+
340+
if passed == total:
341+
print("🎉 All tests passed! Marshmallow 4.x migration is successful.")
342+
return True
343+
else:
344+
print("⚠️ Some tests failed. The migration may need additional fixes.")
345+
return False
346+
347+
if __name__ == "__main__":
348+
success = run_migration_tests()
349+
sys.exit(0 if success else 1)

0 commit comments

Comments
 (0)