Skip to content

Commit 664654d

Browse files
Fix marshmallow 4.x context parameter compatibility issues - Phase 1
Co-authored-by: kshitij-microsoft <[email protected]>
1 parent 080cc8a commit 664654d

File tree

77 files changed

+135
-451
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+135
-451
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/command_component.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def _deserialize(self, value, attr, data, **kwargs):
123123
# Update base_path to parent path of component file.
124124
component_schema_context = deepcopy(self.context)
125125
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
126-
component = AnonymousCommandComponentSchema(context=component_schema_context).load(
126+
component = AnonymousCommandComponentSchema().load(
127127
component_dict
128-
)
128+
, context=component_schema_context)
129129
component._source_path = source_path
130130
component._source = ComponentSource.YAML_COMPONENT
131131
return component

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/data_transfer_component.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ def _deserialize(self, value, attr, data, **kwargs):
206206
# Update base_path to parent path of component file.
207207
component_schema_context = deepcopy(self.context)
208208
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
209-
component = AnonymousDataTransferCopyComponentSchema(context=component_schema_context).load(
209+
component = AnonymousDataTransferCopyComponentSchema().load(
210210
component_dict
211-
)
211+
, context=component_schema_context)
212212
component._source_path = source_path
213213
component._source = ComponentSource.YAML_COMPONENT
214214
return component
@@ -224,9 +224,9 @@ def _deserialize(self, value, attr, data, **kwargs):
224224
# Update base_path to parent path of component file.
225225
component_schema_context = deepcopy(self.context)
226226
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
227-
component = AnonymousDataTransferImportComponentSchema(context=component_schema_context).load(
227+
component = AnonymousDataTransferImportComponentSchema().load(
228228
component_dict
229-
)
229+
, context=component_schema_context)
230230
component._source_path = source_path
231231
component._source = ComponentSource.YAML_COMPONENT
232232
return component
@@ -242,9 +242,9 @@ def _deserialize(self, value, attr, data, **kwargs):
242242
# Update base_path to parent path of component file.
243243
component_schema_context = deepcopy(self.context)
244244
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
245-
component = AnonymousDataTransferExportComponentSchema(context=component_schema_context).load(
245+
component = AnonymousDataTransferExportComponentSchema().load(
246246
component_dict
247-
)
247+
, context=component_schema_context)
248248
component._source_path = source_path
249249
component._source = ComponentSource.YAML_COMPONENT
250250
return component

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/import_component.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def _deserialize(self, value, attr, data, **kwargs):
6464
# Update base_path to parent path of component file.
6565
component_schema_context = deepcopy(self.context)
6666
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
67-
component = AnonymousImportComponentSchema(context=component_schema_context).load(
67+
component = AnonymousImportComponentSchema().load(
6868
component_dict
69-
)
69+
, context=component_schema_context)
7070
component._source_path = source_path
7171
component._source = ComponentSource.YAML_COMPONENT
7272
return component

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/parallel_component.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ def _deserialize(self, value, attr, data, **kwargs):
9494
# Update base_path to parent path of component file.
9595
component_schema_context = deepcopy(self.context)
9696
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
97-
component = AnonymousParallelComponentSchema(context=component_schema_context).load(
97+
component = AnonymousParallelComponentSchema().load(
9898
component_dict
99-
)
99+
, context=component_schema_context)
100100
component._source_path = source_path
101101
component._source = ComponentSource.YAML_COMPONENT
102102
return component

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/spark_component.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def _deserialize(self, value, attr, data, **kwargs):
7070
# Update base_path to parent path of component file.
7171
component_schema_context = deepcopy(self.context)
7272
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
73-
component = AnonymousSparkComponentSchema(context=component_schema_context).load(
73+
component = AnonymousSparkComponentSchema().load(
7474
component_dict
75-
)
75+
, context=component_schema_context)
7676
component._source_path = source_path
7777
component._source = ComponentSource.YAML_COMPONENT
7878
return component

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/schema.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,29 @@
1717
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, FILE_PREFIX, PARAMS_OVERRIDE_KEY
1818
from azure.ai.ml.exceptions import MlException
1919

20+
21+
class ContextAwareSchema(PatchedBaseSchema, metaclass=PatchedSchemaMeta):
22+
"""A marshmallow 4.x compatible schema that can store and use context"""
23+
24+
def __init__(self, *args, **kwargs):
25+
# Store context as instance variable (marshmallow 4.x doesn't accept it in constructor)
26+
self.context = kwargs.pop("context", None)
27+
28+
# In marshmallow 4.x, filter out unsupported constructor parameters
29+
# Valid parameters for Schema constructor: only, exclude, many, load_only, dump_only, partial
30+
valid_schema_params = {'only', 'exclude', 'many', 'load_only', 'dump_only', 'partial'}
31+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_schema_params}
32+
33+
super().__init__(*args, **filtered_kwargs)
34+
35+
def load(self, json_data, *, many=None, partial=None, unknown=None):
36+
"""Override load to use stored context"""
37+
return super().load(json_data, many=many, partial=partial, unknown=unknown, context=self.context)
38+
39+
def dump(self, obj, *, many=None):
40+
"""Override dump to use stored context"""
41+
return super().dump(obj, many=many, context=self.context)
42+
2043
module_logger = logging.getLogger(__name__)
2144

2245

@@ -34,12 +57,21 @@ def __init__(self, *args, **kwargs):
3457
self.old_base_path = self.context.get(BASE_PATH_CONTEXT_KEY)
3558

3659
# In marshmallow 4.x, filter out unsupported constructor parameters
37-
# Valid parameters for Schema constructor: only, exclude, many, context, load_only, dump_only, partial
38-
valid_schema_params = {'only', 'exclude', 'many', 'context', 'load_only', 'dump_only', 'partial'}
60+
# Valid parameters for Schema constructor: only, exclude, many, load_only, dump_only, partial
61+
# Note: context is NOT a valid constructor parameter in marshmallow 4.x
62+
valid_schema_params = {'only', 'exclude', 'many', 'load_only', 'dump_only', 'partial'}
3963
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_schema_params}
4064

4165
super().__init__(*args, **filtered_kwargs)
4266

67+
def load(self, json_data, *, many=None, partial=None, unknown=None):
68+
"""Override load to use stored context"""
69+
return super().load(json_data, many=many, partial=partial, unknown=unknown, context=self.context)
70+
71+
def dump(self, obj, *, many=None):
72+
"""Override dump to use stored context"""
73+
return super().dump(obj, many=many, context=self.context)
74+
4375
@pre_load
4476
def add_param_overrides(self, data, **kwargs):
4577
# Removing params override from context so that overriding is done once on the yaml

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_component.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ def _deserialize(self, value, attr, data, **kwargs):
247247
# Update base_path to parent path of component file.
248248
component_schema_context = deepcopy(self.context)
249249
component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
250-
component = _AnonymousPipelineComponentSchema(context=component_schema_context).load(
250+
component = _AnonymousPipelineComponentSchema().load(
251251
component_dict
252-
)
252+
, context=component_schema_context)
253253
component._source_path = source_path
254254
component._source = ComponentSource.YAML_COMPONENT
255255
return component

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/_package/base_environment_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _from_rest_object(cls, rest_obj: RestBaseEnvironmentId) -> "RestBaseEnvironm
4242
return BaseEnvironment(type=rest_obj.base_environment_source_type, resource_id=rest_obj.resource_id)
4343

4444
def _to_dict(self) -> Dict:
45-
return dict(BaseEnvironmentSourceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
45+
return dict(BaseEnvironmentSourceSchema().dump(self, context={BASE_PATH_CONTEXT_KEY: "./"}))
4646

4747
def _to_rest_object(self) -> RestBaseEnvironmentId:
4848
return RestBaseEnvironmentId(base_environment_source_type=self.type, resource_id=self.resource_id)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/_package/model_package.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def dump(
296296
dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False)
297297

298298
def _to_dict(self) -> Dict:
299-
return dict(ModelPackageSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self))
299+
return dict(ModelPackageSchema().dump(self, context={BASE_PATH_CONTEXT_KEY: "./"}))
300300

301301
@classmethod
302302
def _from_rest_object(cls, model_package_rest_object: PackageResponse) -> Any:

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/code.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _load(
8585
return res
8686

8787
def _to_dict(self) -> Dict:
88-
res: dict = CodeAssetSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
88+
res: dict = CodeAssetSchema().dump(self, context={BASE_PATH_CONTEXT_KEY: "./"})
8989
return res
9090

9191
@classmethod

0 commit comments

Comments
 (0)