@@ -587,7 +587,11 @@ def get_plugin_classes(namespace: types.ModuleType, class_name: str) -> dict[str
587587 }
588588
589589
590- def task_def_as_dict (task_def : "type[Task]" ) -> ty .Dict [str , ty .Any ]:
590+ def task_def_as_dict (
591+ task_def : "type[Task]" ,
592+ filter : ty .Callable [[attrs .Attribute , ty .Any ], bool ] | None = None ,
593+ ** kwargs : ty .Any ,
594+ ) -> ty .Dict [str , ty .Any ]:
591595 """Converts a Pydra task class into a dictionary representation that can be serialized
592596 and saved to a file, then read and passed to an appropriate `pydra.compose.*.define`
593597 method to recreate the task.
@@ -596,6 +600,13 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
596600 ----------
597601 task_def : type[pydra.compose.base.Task]
598602 The Pydra task class to convert.
603+ filter : callable, optional
604+ A function to filter out certain attributes from the task definition passed
605+ through to `attrs.asdict`. It should take an attribute and its value as
606+ arguments and return a boolean (True to keep the attribute, False to filter it out).
607+ **kwargs : dict
608+ Additional keyword arguments to pass to `attrs.asdict` (i.e. all except `filter`)
609+ See the `attrs` documentation for more details.
599610
600611 Returns
601612 -------
@@ -604,18 +615,21 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
604615 """
605616 from pydra .compose .base import Out
606617
618+ if filter is None :
619+ filter = _filter_out_defaults
620+
607621 input_fields = task_fields (task_def )
608622 executor = input_fields .pop (task_def ._executor_name ).default
609623 input_dicts = [
610- attrs .asdict (i , filter = _filter_defaults )
624+ attrs .asdict (i , filter = filter , ** kwargs )
611625 for i in input_fields
612626 if (
613627 not isinstance (i , Out ) # filter out outarg fields
614628 and i .name not in task_def .BASE_ATTRS
615629 )
616630 ]
617631 output_dicts = [
618- attrs .asdict (o , filter = _filter_defaults )
632+ attrs .asdict (o , filter = filter , ** kwargs )
619633 for o in task_fields (task_def .Outputs )
620634 if o .name not in task_def .Outputs .BASE_ATTRS
621635 ]
@@ -650,7 +664,10 @@ def task_def_from_dict(task_def_dict: dict[str, ty.Any]) -> type["Task"]:
650664 return mod .define (dct .pop (mod .Task ._executor_name ), ** dct )
651665
652666
653- def _filter_defaults (atr : attrs .Attribute , value : ty .Any ) -> bool :
667+ def _filter_out_defaults (atr : attrs .Attribute , value : ty .Any ) -> bool :
668+ """Filter out values that match the attributes default value."""
669+ if isinstance (atr .default , attrs .Factory ) and atr .default .factory () == value :
670+ return False
654671 if value == atr .default :
655672 return False
656673 return True
0 commit comments