Skip to content

Commit 5f9e2d4

Browse files
committed
apply value serializer to task_class_attrs
1 parent 4391689 commit 5f9e2d4

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

pydra/utils/general.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,9 @@ def get_plugin_classes(namespace: types.ModuleType, class_name: str) -> dict[str
590590
def task_def_as_dict(
591591
task_def: "type[Task]",
592592
filter: ty.Callable[[attrs.Attribute, ty.Any], bool] | None = None,
593+
value_serializer: (
594+
ty.Callable[[ty.Any, attrs.Attribute, ty.Any], ty.Any] | None
595+
) = None,
593596
**kwargs: ty.Any,
594597
) -> ty.Dict[str, ty.Any]:
595598
"""Converts a Pydra task class into a dictionary representation that can be serialized
@@ -604,9 +607,14 @@ def task_def_as_dict(
604607
A function to filter out certain attributes from the task definition passed
605608
through to `attrs.asdict`. It should take an attribute and its value as
606609
arguments and return a boolean (True to keep the attribute, False to filter it out).
610+
value_serializer : callable, optional
611+
A function to serialize the value of an attribute. It should take the task
612+
definition, the attribute, and its value as arguments and return a serialized
613+
value.
607614
**kwargs : dict
608-
Additional keyword arguments to pass to `attrs.asdict` (i.e. all except `filter`)
609-
See the `attrs` documentation for more details.
615+
Additional keyword arguments to pass to `attrs.asdict` (i.e. all except `filter`,
616+
and `value_serializer`), e.g. `recurse`. See the `attrs` documentation for more
617+
details.
610618
611619
Returns
612620
-------
@@ -621,15 +629,15 @@ def task_def_as_dict(
621629
input_fields = task_fields(task_def)
622630
executor = input_fields.pop(task_def._executor_name).default
623631
input_dicts = [
624-
attrs.asdict(i, filter=filter, **kwargs)
632+
attrs.asdict(i, filter=filter, value_serializer=value_serializer, **kwargs)
625633
for i in input_fields
626634
if (
627635
not isinstance(i, Out) # filter out outarg fields
628636
and i.name not in task_def.BASE_ATTRS
629637
)
630638
]
631639
output_dicts = [
632-
attrs.asdict(o, filter=filter, **kwargs)
640+
attrs.asdict(o, filter=filter, value_serializer=value_serializer, **kwargs)
633641
for o in task_fields(task_def.Outputs)
634642
if o.name not in task_def.Outputs.BASE_ATTRS
635643
]
@@ -640,7 +648,14 @@ def task_def_as_dict(
640648
"inputs": {d.pop("name"): d for d in input_dicts},
641649
"outputs": {d.pop("name"): d for d in output_dicts},
642650
}
643-
dct.update({a: getattr(task_def, "_" + a) for a in task_def.TASK_CLASS_ATTRS})
651+
class_attrs = {a: getattr(task_def, "_" + a) for a in task_def.TASK_CLASS_ATTRS}
652+
if value_serializer:
653+
attrs_fields = {f.name: f for f in attrs.fields(task_def)}
654+
class_attrs = {
655+
n: value_serializer(task_def, attrs_fields[n], v)
656+
for n, v in class_attrs.items()
657+
}
658+
dct.update(class_attrs)
644659

645660
return dct
646661

0 commit comments

Comments
 (0)