11import typing as ty
22import attrs
3+ from collections .abc import Collection
34from pydra .compose import python , workflow , shell
45from fileformats .text import TextFile
56from pydra .utils .general import (
67 serialize_task_class ,
78 unserialize_task_class ,
89 get_fields ,
10+ filter_out_defaults ,
911)
1012from pydra .utils .tests .utils import Concatenate
1113
1214
15+ def check_dict_fully_serialized (dct : dict ):
16+ """Checks if there are any Pydra objects or non list/dict containers in the dict."""
17+ stack = [dct ]
18+ while stack :
19+ item = stack .pop ()
20+ if isinstance (item , dict ):
21+ stack .extend (item .values ())
22+ elif isinstance (item , list ):
23+ stack .extend (item )
24+ elif isinstance (item , str ):
25+ pass
26+ elif isinstance (item , Collection ):
27+ raise ValueError (f"Unserializable container object { item } found in dict" )
28+ elif type (item ).__module__ .split ("." )[0 ] == "pydra" :
29+ raise ValueError (f"Unserialized Pydra object { item } found in dict" )
30+
31+
1332@python .define (outputs = ["out_int" ], xor = ["b" , "c" ])
1433def Add (a : int , b : int | None = None , c : int | None = None ) -> int :
1534 """
@@ -32,12 +51,15 @@ def Add(a: int, b: int | None = None, c: int | None = None) -> int:
3251
3352def test_python_serialize_task_class (tmp_path ):
3453
54+ assert Add (a = 1 , b = 2 )(cache_root = tmp_path / "cache1" ).out_int == 3
55+
3556 dct = serialize_task_class (Add )
57+ assert isinstance (dct , dict )
58+ check_dict_fully_serialized (dct )
3659 Reloaded = unserialize_task_class (dct )
3760 assert get_fields (Add ) == get_fields (Reloaded )
3861
39- add = Reloaded (a = 1 , b = 2 )
40- assert add (cache_root = tmp_path / "cache" ).out_int == 3
62+ assert Reloaded (a = 1 , b = 2 )(cache_root = tmp_path / "cache2" ).out_int == 3
4163
4264
4365def test_shell_serialize_task_class ():
@@ -47,6 +69,8 @@ def test_shell_serialize_task_class():
4769 )
4870
4971 dct = serialize_task_class (MyCmd )
72+ assert isinstance (dct , dict )
73+ check_dict_fully_serialized (dct )
5074 Reloaded = unserialize_task_class (dct )
5175 assert get_fields (MyCmd ) == get_fields (Reloaded )
5276
@@ -61,6 +85,8 @@ def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
6185 return concatenate .out_file
6286
6387 dct = serialize_task_class (AWorkflow )
88+ assert isinstance (dct , dict )
89+ check_dict_fully_serialized (dct )
6490 Reloaded = unserialize_task_class (dct )
6591 assert get_fields (AWorkflow ) == get_fields (Reloaded )
6692
@@ -73,15 +99,57 @@ def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
7399
74100def test_serialize_task_class_with_value_serializer ():
75101
76- def frozen_set_to_list_serializer (
77- mock_class : ty .Any , atr : attrs .Attribute , value : ty .Any
102+ @python .define
103+ def Identity (a : int ) -> int :
104+ """
105+ Parameters
106+ ----------
107+ a: int
108+ the arg
109+
110+ Returns
111+ -------
112+ out : int
113+ a returned as is
114+ """
115+ return a
116+
117+ def type_to_str_serializer (
118+ klass : ty .Any , atr : attrs .Attribute , value : ty .Any
78119 ) -> ty .Any :
79- # This is just a dummy serializer
80- if isinstance (value , frozenset ):
81- return list (
82- frozen_set_to_list_serializer (mock_class , atr , v ) for v in value
83- )
120+ if isinstance (value , type ):
121+ return value .__module__ + "." + value .__name__
84122 return value
85123
86- dct = serialize_task_class (Add , value_serializer = frozen_set_to_list_serializer )
87- assert dct ["xor" ] == [["b" , "c" ]] or dct ["xor" ] == [["c" , "b" ]]
124+ dct = serialize_task_class (Identity , value_serializer = type_to_str_serializer )
125+ assert isinstance (dct , dict )
126+ check_dict_fully_serialized (dct )
127+ assert dct ["inputs" ] == {"a" : {"type" : "builtins.int" , "help" : "the arg" }}
128+
129+
130+ def test_serialize_task_class_with_filter ():
131+
132+ @python .define
133+ def Identity (a : int ) -> int :
134+ """
135+ Parameters
136+ ----------
137+ a: int
138+ the arg
139+
140+ Returns
141+ -------
142+ out : int
143+ a returned as is
144+ """
145+ return a
146+
147+ def no_helps_filter (atr : attrs .Attribute , value : ty .Any ) -> bool :
148+ if atr .name == "help" :
149+ return False
150+ return filter_out_defaults (atr , value )
151+
152+ dct = serialize_task_class (Identity , filter = no_helps_filter )
153+ assert isinstance (dct , dict )
154+ check_dict_fully_serialized (dct )
155+ assert dct ["inputs" ] == {"a" : {"type" : int }}
0 commit comments