Skip to content

Commit 4391689

Browse files
committed
pass through task_def_as_dict kwargs to attrs.asdict
1 parent 4bdcf78 commit 4391689

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

pydra/utils/general.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,11 @@ def get_plugin_classes(namespace: types.ModuleType, class_name: str) -> dict[str
587587
}
588588

589589

590-
def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
590+
def task_def_as_dict(
591+
task_def: "type[Task]",
592+
filter: ty.Callable[[attrs.Attribute, ty.Any], bool] | None = None,
593+
**kwargs: ty.Any,
594+
) -> ty.Dict[str, ty.Any]:
591595
"""Converts a Pydra task class into a dictionary representation that can be serialized
592596
and saved to a file, then read and passed to an appropriate `pydra.compose.*.define`
593597
method to recreate the task.
@@ -596,6 +600,13 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
596600
----------
597601
task_def : type[pydra.compose.base.Task]
598602
The Pydra task class to convert.
603+
filter : callable, optional
604+
A function to filter out certain attributes from the task definition passed
605+
through to `attrs.asdict`. It should take an attribute and its value as
606+
arguments and return a boolean (True to keep the attribute, False to filter it out).
607+
**kwargs : dict
608+
Additional keyword arguments to pass to `attrs.asdict` (i.e. all except `filter`)
609+
See the `attrs` documentation for more details.
599610
600611
Returns
601612
-------
@@ -604,18 +615,21 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
604615
"""
605616
from pydra.compose.base import Out
606617

618+
if filter is None:
619+
filter = _filter_out_defaults
620+
607621
input_fields = task_fields(task_def)
608622
executor = input_fields.pop(task_def._executor_name).default
609623
input_dicts = [
610-
attrs.asdict(i, filter=_filter_defaults)
624+
attrs.asdict(i, filter=filter, **kwargs)
611625
for i in input_fields
612626
if (
613627
not isinstance(i, Out) # filter out outarg fields
614628
and i.name not in task_def.BASE_ATTRS
615629
)
616630
]
617631
output_dicts = [
618-
attrs.asdict(o, filter=_filter_defaults)
632+
attrs.asdict(o, filter=filter, **kwargs)
619633
for o in task_fields(task_def.Outputs)
620634
if o.name not in task_def.Outputs.BASE_ATTRS
621635
]
@@ -650,7 +664,10 @@ def task_def_from_dict(task_def_dict: dict[str, ty.Any]) -> type["Task"]:
650664
return mod.define(dct.pop(mod.Task._executor_name), **dct)
651665

652666

653-
def _filter_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:
667+
def _filter_out_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:
668+
"""Filter out values that match the attributes default value."""
669+
if isinstance(atr.default, attrs.Factory) and atr.default.factory() == value:
670+
return False
654671
if value == atr.default:
655672
return False
656673
return True

0 commit comments

Comments
 (0)