Skip to content

Commit c346968

Browse files
committed
replace fronzensets with lists in task class serialization and add filter test
1 parent 9888cc6 commit c346968

File tree

3 files changed

+105
-18
lines changed

3 files changed

+105
-18
lines changed

pydra/compose/base/field.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Requirement:
6060

6161
name: str
6262
allowed_values: list[str] | None = attrs.field(
63-
default=None, converter=allowed_values_converter
63+
factory=None, converter=allowed_values_converter
6464
)
6565

6666
def satisfied(self, inputs: "Task") -> bool:
@@ -326,7 +326,7 @@ class Arg(Field):
326326
it is False
327327
"""
328328

329-
allowed_values: frozenset = attrs.field(default=(), converter=frozenset)
329+
allowed_values: frozenset = attrs.field(factory=frozenset, converter=frozenset)
330330
copy_mode: File.CopyMode = File.CopyMode.any
331331
copy_collation: File.CopyCollation = File.CopyCollation.any
332332
copy_ext_decomp: File.ExtensionDecomposition = File.ExtensionDecomposition.single

pydra/utils/general.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import sys
66
import typing as ty
7+
from collections.abc import Mapping, Collection
78
from copy import copy
89
import re
910
import attrs
@@ -624,20 +625,36 @@ def serialize_task_class(
624625
from pydra.compose.base import Out
625626

626627
if filter is None:
627-
filter = _filter_out_defaults
628+
filter = filter_out_defaults
629+
630+
def full_val_serializer(
631+
obj: ty.Any, field: attrs.Attribute, value: ty.Any
632+
) -> ty.Any:
633+
"""A wrapper for the value serializer to handle the case where it is None."""
634+
if value_serializer is not None:
635+
value = value_serializer(obj, field, value)
636+
if isinstance(value, str):
637+
return value
638+
if isinstance(value, Mapping) and not isinstance(value, dict):
639+
# If the value is a mapping, convert it to a dict
640+
value = dict(value)
641+
elif isinstance(value, Collection) and not isinstance(value, list):
642+
# If the value is not a collection or string, convert it to a list
643+
value = list(value)
644+
return value
628645

629646
input_fields = get_fields(task_class)
630647
executor = input_fields.pop(task_class._executor_name).default
631648
input_dicts = [
632-
attrs.asdict(i, filter=filter, value_serializer=value_serializer, **kwargs)
649+
attrs.asdict(i, filter=filter, value_serializer=full_val_serializer, **kwargs)
633650
for i in input_fields
634651
if (
635652
not isinstance(i, Out) # filter out outarg fields
636653
and i.name not in task_class.BASE_ATTRS
637654
)
638655
]
639656
output_dicts = [
640-
attrs.asdict(o, filter=filter, value_serializer=value_serializer, **kwargs)
657+
attrs.asdict(o, filter=filter, value_serializer=full_val_serializer, **kwargs)
641658
for o in get_fields(task_class.Outputs)
642659
if o.name not in task_class.Outputs.BASE_ATTRS
643660
]
@@ -649,6 +666,8 @@ def serialize_task_class(
649666
"outputs": {d.pop("name"): d for d in output_dicts},
650667
}
651668
class_attrs = {a: getattr(task_class, "_" + a) for a in task_class.TASK_CLASS_ATTRS}
669+
# Convert the frozensets to lists for serialization
670+
class_attrs["xor"] = [list(x) for x in class_attrs.pop("xor")]
652671
if value_serializer:
653672
# We need to create a mock attrs object for the class attrs to apply the
654673
# value_serializer to it
@@ -658,7 +677,7 @@ def serialize_task_class(
658677
attrs_fields = {f.name: f for f in attrs.fields(mock_cls)}
659678
mock = mock_cls(**class_attrs)
660679
class_attrs = {
661-
n: value_serializer(mock, attrs_fields[n], v)
680+
n: full_val_serializer(mock, attrs_fields[n], v)
662681
for n, v in class_attrs.items()
663682
}
664683
dct.update(class_attrs)
@@ -685,7 +704,7 @@ def unserialize_task_class(task_class_dict: dict[str, ty.Any]) -> type["Task"]:
685704
return mod.define(dct.pop(mod.Task._executor_name), **dct)
686705

687706

688-
def _filter_out_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:
707+
def filter_out_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:
689708
"""Filter out values that match the attributes default value."""
690709
if isinstance(atr.default, attrs.Factory) and atr.default.factory() == value:
691710
return False

pydra/utils/tests/test_general.py

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,34 @@
11
import typing as ty
22
import attrs
3+
from collections.abc import Collection
34
from pydra.compose import python, workflow, shell
45
from fileformats.text import TextFile
56
from pydra.utils.general import (
67
serialize_task_class,
78
unserialize_task_class,
89
get_fields,
10+
filter_out_defaults,
911
)
1012
from pydra.utils.tests.utils import Concatenate
1113

1214

15+
def check_dict_fully_serialized(dct: dict):
16+
"""Checks if there are any Pydra objects or non list/dict containers in the dict."""
17+
stack = [dct]
18+
while stack:
19+
item = stack.pop()
20+
if isinstance(item, dict):
21+
stack.extend(item.values())
22+
elif isinstance(item, list):
23+
stack.extend(item)
24+
elif isinstance(item, str):
25+
pass
26+
elif isinstance(item, Collection):
27+
raise ValueError(f"Unserializable container object {item} found in dict")
28+
elif type(item).__module__.split(".")[0] == "pydra":
29+
raise ValueError(f"Unserialized Pydra object {item} found in dict")
30+
31+
1332
@python.define(outputs=["out_int"], xor=["b", "c"])
1433
def Add(a: int, b: int | None = None, c: int | None = None) -> int:
1534
"""
@@ -32,12 +51,15 @@ def Add(a: int, b: int | None = None, c: int | None = None) -> int:
3251

3352
def test_python_serialize_task_class(tmp_path):
3453

54+
assert Add(a=1, b=2)(cache_root=tmp_path / "cache1").out_int == 3
55+
3556
dct = serialize_task_class(Add)
57+
assert isinstance(dct, dict)
58+
check_dict_fully_serialized(dct)
3659
Reloaded = unserialize_task_class(dct)
3760
assert get_fields(Add) == get_fields(Reloaded)
3861

39-
add = Reloaded(a=1, b=2)
40-
assert add(cache_root=tmp_path / "cache").out_int == 3
62+
assert Reloaded(a=1, b=2)(cache_root=tmp_path / "cache2").out_int == 3
4163

4264

4365
def test_shell_serialize_task_class():
@@ -47,6 +69,8 @@ def test_shell_serialize_task_class():
4769
)
4870

4971
dct = serialize_task_class(MyCmd)
72+
assert isinstance(dct, dict)
73+
check_dict_fully_serialized(dct)
5074
Reloaded = unserialize_task_class(dct)
5175
assert get_fields(MyCmd) == get_fields(Reloaded)
5276

@@ -61,6 +85,8 @@ def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
6185
return concatenate.out_file
6286

6387
dct = serialize_task_class(AWorkflow)
88+
assert isinstance(dct, dict)
89+
check_dict_fully_serialized(dct)
6490
Reloaded = unserialize_task_class(dct)
6591
assert get_fields(AWorkflow) == get_fields(Reloaded)
6692

@@ -73,15 +99,57 @@ def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
7399

74100
def test_serialize_task_class_with_value_serializer():
75101

76-
def frozen_set_to_list_serializer(
77-
mock_class: ty.Any, atr: attrs.Attribute, value: ty.Any
102+
@python.define
103+
def Identity(a: int) -> int:
104+
"""
105+
Parameters
106+
----------
107+
a: int
108+
the arg
109+
110+
Returns
111+
-------
112+
out : int
113+
a returned as is
114+
"""
115+
return a
116+
117+
def type_to_str_serializer(
118+
klass: ty.Any, atr: attrs.Attribute, value: ty.Any
78119
) -> ty.Any:
79-
# This is just a dummy serializer
80-
if isinstance(value, frozenset):
81-
return list(
82-
frozen_set_to_list_serializer(mock_class, atr, v) for v in value
83-
)
120+
if isinstance(value, type):
121+
return value.__module__ + "." + value.__name__
84122
return value
85123

86-
dct = serialize_task_class(Add, value_serializer=frozen_set_to_list_serializer)
87-
assert dct["xor"] == [["b", "c"]] or dct["xor"] == [["c", "b"]]
124+
dct = serialize_task_class(Identity, value_serializer=type_to_str_serializer)
125+
assert isinstance(dct, dict)
126+
check_dict_fully_serialized(dct)
127+
assert dct["inputs"] == {"a": {"type": "builtins.int", "help": "the arg"}}
128+
129+
130+
def test_serialize_task_class_with_filter():
131+
132+
@python.define
133+
def Identity(a: int) -> int:
134+
"""
135+
Parameters
136+
----------
137+
a: int
138+
the arg
139+
140+
Returns
141+
-------
142+
out : int
143+
a returned as is
144+
"""
145+
return a
146+
147+
def no_helps_filter(atr: attrs.Attribute, value: ty.Any) -> bool:
148+
if atr.name == "help":
149+
return False
150+
return filter_out_defaults(atr, value)
151+
152+
dct = serialize_task_class(Identity, filter=no_helps_filter)
153+
assert isinstance(dct, dict)
154+
check_dict_fully_serialized(dct)
155+
assert dct["inputs"] == {"a": {"type": int}}

0 commit comments

Comments
 (0)