Skip to content

Commit 9e4a94b

Browse files
committed
apply value_serializer to class attributes
1 parent 8dd5429 commit 9e4a94b

File tree

2 files changed

+86
-33
lines changed

2 files changed

+86
-33
lines changed

pydra/utils/general.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,15 @@ def task_class_as_dict(
650650
}
651651
class_attrs = {a: getattr(task_class, "_" + a) for a in task_class.TASK_CLASS_ATTRS}
652652
if value_serializer:
653-
attrs_fields = {f.name: f for f in attrs.fields(task_class)}
653+
# We need to create a mock attrs object for the class attrs to apply the
654+
# value_serializer to it
655+
mock_cls = _make_attrs_class(
656+
{a: task_class.__annotations__.get(a, ty.Any) for a in class_attrs}
657+
)
658+
attrs_fields = {f.name: f for f in attrs.fields(mock_cls)}
659+
mock = mock_cls(**class_attrs)
654660
class_attrs = {
655-
n: value_serializer(task_class, attrs_fields[n], v)
661+
n: value_serializer(mock, attrs_fields[n], v)
656662
for n, v in class_attrs.items()
657663
}
658664
dct.update(class_attrs)
@@ -686,3 +692,25 @@ def _filter_out_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:
686692
if value == atr.default:
687693
return False
688694
return True
695+
696+
697+
def _make_attrs_class(field_types: dict[str, type]) -> type:
698+
"""Creates an attrs given the a dictionary of field names and their types.
699+
700+
Parameters
701+
----------
702+
field_types : dict[str, type]
703+
A dictionary mapping field names to their types.
704+
705+
Returns
706+
-------
707+
type
708+
the attrs class.
709+
"""
710+
return attrs.define(
711+
type(
712+
"MockAttrsClass",
713+
(),
714+
{n: attrs.field(type=t) for n, t in field_types.items()},
715+
)
716+
)

pydra/utils/tests/test_general.py

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,40 @@
1+
import typing as ty
2+
import attrs
13
from pydra.compose import python, workflow, shell
2-
from fileformats.generic import File
4+
from fileformats.text import TextFile
35
from pydra.utils.general import task_class_as_dict, task_class_from_dict, task_fields
4-
from pydra.utils.tests.utils import SpecificFuncTask, Concatenate
5-
6-
7-
def test_python_task_class_as_dict():
8-
9-
@python.define(outputs=["out_int"], xor=["b", "c"])
10-
def Add(a: int, b: int | None = None, c: int | None = None) -> int:
11-
"""
12-
Parameters
13-
----------
14-
a: int
15-
the first arg
16-
b : int, optional
17-
the optional second arg
18-
c : int, optional
19-
the optional third arg
20-
21-
Returns
22-
-------
23-
out_int : int
24-
the sum of a and b
25-
"""
26-
return a + (b if b is not None else c)
6+
from pydra.utils.tests.utils import Concatenate
7+
8+
9+
@python.define(outputs=["out_int"], xor=["b", "c"])
10+
def Add(a: int, b: int | None = None, c: int | None = None) -> int:
11+
"""
12+
Parameters
13+
----------
14+
a: int
15+
the first arg
16+
b : int, optional
17+
the optional second arg
18+
c : int, optional
19+
the optional third arg
20+
21+
Returns
22+
-------
23+
out_int : int
24+
the sum of a and b
25+
"""
26+
return a + (b if b is not None else c)
27+
28+
29+
def test_python_task_class_as_dict(tmp_path):
2730

2831
dct = task_class_as_dict(Add)
2932
Reloaded = task_class_from_dict(dct)
3033
assert task_fields(Add) == task_fields(Reloaded)
3134

35+
add = Reloaded(a=1, b=2)
36+
assert add(cache_root=tmp_path / "cache").out_int == 3
37+
3238

3339
def test_shell_task_class_as_dict():
3440

@@ -41,18 +47,37 @@ def test_shell_task_class_as_dict():
4147
assert task_fields(MyCmd) == task_fields(Reloaded)
4248

4349

44-
def test_workflow_task_class_as_dict():
50+
def test_workflow_task_class_as_dict(tmp_path):
4551

46-
@workflow.define
47-
def AWorkflow(in_file: File, a_param: int) -> tuple[File, File]:
48-
spec_func = workflow.add(SpecificFuncTask(in_file))
52+
@workflow.define(outputs=["out_file"])
53+
def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
4954
concatenate = workflow.add(
50-
Concatenate(
51-
in_file1=in_file, in_file2=spec_func.out_file, duplicates=a_param
52-
)
55+
Concatenate(in_file1=in_file, in_file2=in_file, duplicates=a_param)
5356
)
5457
return concatenate.out_file
5558

5659
dct = task_class_as_dict(AWorkflow)
5760
Reloaded = task_class_from_dict(dct)
5861
assert task_fields(AWorkflow) == task_fields(Reloaded)
62+
63+
foo_file = tmp_path / "file1.txt"
64+
foo_file.write_text("foo")
65+
66+
outputs = Reloaded(in_file=foo_file, a_param=2)(cache_root=tmp_path / "cache")
67+
assert outputs.out_file.contents == "foo\nfoo\nfoo\nfoo"
68+
69+
70+
def test_task_class_as_dict_with_value_serializer():
71+
72+
def frozen_set_to_list_serializer(
73+
mock_class: ty.Any, atr: attrs.Attribute, value: ty.Any
74+
) -> ty.Any:
75+
# This is just a dummy serializer
76+
if isinstance(value, frozenset):
77+
return list(
78+
frozen_set_to_list_serializer(mock_class, atr, v) for v in value
79+
)
80+
return value
81+
82+
dct = task_class_as_dict(Add, value_serializer=frozen_set_to_list_serializer)
83+
assert dct["xor"] == [["b", "c"]] or dct["xor"] == [["c", "b"]]

0 commit comments

Comments
 (0)