@@ -587,7 +587,11 @@ def get_plugin_classes(namespace: types.ModuleType, class_name: str) -> dict[str
587
587
}
588
588
589
589
590
- def task_def_as_dict (task_def : "type[Task]" ) -> ty .Dict [str , ty .Any ]:
590
+ def task_def_as_dict (
591
+ task_def : "type[Task]" ,
592
+ filter : ty .Callable [[attrs .Attribute , ty .Any ], bool ] | None = None ,
593
+ ** kwargs : ty .Any ,
594
+ ) -> ty .Dict [str , ty .Any ]:
591
595
"""Converts a Pydra task class into a dictionary representation that can be serialized
592
596
and saved to a file, then read and passed to an appropriate `pydra.compose.*.define`
593
597
method to recreate the task.
@@ -596,6 +600,13 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
596
600
----------
597
601
task_def : type[pydra.compose.base.Task]
598
602
The Pydra task class to convert.
603
+ filter : callable, optional
604
+ A function to filter out certain attributes from the task definition passed
605
+ through to `attrs.asdict`. It should take an attribute and its value as
606
+ arguments and return a boolean (True to keep the attribute, False to filter it out).
607
+ **kwargs : dict
608
+ Additional keyword arguments to pass to `attrs.asdict` (i.e. all except `filter`)
609
+ See the `attrs` documentation for more details.
599
610
600
611
Returns
601
612
-------
@@ -604,18 +615,21 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
604
615
"""
605
616
from pydra .compose .base import Out
606
617
618
+ if filter is None :
619
+ filter = _filter_out_defaults
620
+
607
621
input_fields = task_fields (task_def )
608
622
executor = input_fields .pop (task_def ._executor_name ).default
609
623
input_dicts = [
610
- attrs .asdict (i , filter = _filter_defaults )
624
+ attrs .asdict (i , filter = filter , ** kwargs )
611
625
for i in input_fields
612
626
if (
613
627
not isinstance (i , Out ) # filter out outarg fields
614
628
and i .name not in task_def .BASE_ATTRS
615
629
)
616
630
]
617
631
output_dicts = [
618
- attrs .asdict (o , filter = _filter_defaults )
632
+ attrs .asdict (o , filter = filter , ** kwargs )
619
633
for o in task_fields (task_def .Outputs )
620
634
if o .name not in task_def .Outputs .BASE_ATTRS
621
635
]
@@ -650,7 +664,10 @@ def task_def_from_dict(task_def_dict: dict[str, ty.Any]) -> type["Task"]:
650
664
return mod .define (dct .pop (mod .Task ._executor_name ), ** dct )
651
665
652
666
653
- def _filter_defaults (atr : attrs .Attribute , value : ty .Any ) -> bool :
667
+ def _filter_out_defaults (atr : attrs .Attribute , value : ty .Any ) -> bool :
668
+ """Filter out values that match the attributes default value."""
669
+ if isinstance (atr .default , attrs .Factory ) and atr .default .factory () == value :
670
+ return False
654
671
if value == atr .default :
655
672
return False
656
673
return True
0 commit comments