1
1
import typing as ty
2
2
import attrs
3
+ from collections .abc import Collection
3
4
from pydra .compose import python , workflow , shell
4
5
from fileformats .text import TextFile
5
6
from pydra .utils .general import (
6
7
serialize_task_class ,
7
8
unserialize_task_class ,
8
9
get_fields ,
10
+ filter_out_defaults ,
9
11
)
10
12
from pydra .utils .tests .utils import Concatenate
11
13
12
14
15
+ def check_dict_fully_serialized (dct : dict ):
16
+ """Checks if there are any Pydra objects or non list/dict containers in the dict."""
17
+ stack = [dct ]
18
+ while stack :
19
+ item = stack .pop ()
20
+ if isinstance (item , dict ):
21
+ stack .extend (item .values ())
22
+ elif isinstance (item , list ):
23
+ stack .extend (item )
24
+ elif isinstance (item , str ):
25
+ pass
26
+ elif isinstance (item , Collection ):
27
+ raise ValueError (f"Unserializable container object { item } found in dict" )
28
+ elif type (item ).__module__ .split ("." )[0 ] == "pydra" :
29
+ raise ValueError (f"Unserialized Pydra object { item } found in dict" )
30
+
31
+
13
32
@python .define (outputs = ["out_int" ], xor = ["b" , "c" ])
14
33
def Add (a : int , b : int | None = None , c : int | None = None ) -> int :
15
34
"""
@@ -32,12 +51,15 @@ def Add(a: int, b: int | None = None, c: int | None = None) -> int:
32
51
33
52
def test_python_serialize_task_class (tmp_path ):
34
53
54
+ assert Add (a = 1 , b = 2 )(cache_root = tmp_path / "cache1" ).out_int == 3
55
+
35
56
dct = serialize_task_class (Add )
57
+ assert isinstance (dct , dict )
58
+ check_dict_fully_serialized (dct )
36
59
Reloaded = unserialize_task_class (dct )
37
60
assert get_fields (Add ) == get_fields (Reloaded )
38
61
39
- add = Reloaded (a = 1 , b = 2 )
40
- assert add (cache_root = tmp_path / "cache" ).out_int == 3
62
+ assert Reloaded (a = 1 , b = 2 )(cache_root = tmp_path / "cache2" ).out_int == 3
41
63
42
64
43
65
def test_shell_serialize_task_class ():
@@ -47,6 +69,8 @@ def test_shell_serialize_task_class():
47
69
)
48
70
49
71
dct = serialize_task_class (MyCmd )
72
+ assert isinstance (dct , dict )
73
+ check_dict_fully_serialized (dct )
50
74
Reloaded = unserialize_task_class (dct )
51
75
assert get_fields (MyCmd ) == get_fields (Reloaded )
52
76
@@ -61,6 +85,8 @@ def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
61
85
return concatenate .out_file
62
86
63
87
dct = serialize_task_class (AWorkflow )
88
+ assert isinstance (dct , dict )
89
+ check_dict_fully_serialized (dct )
64
90
Reloaded = unserialize_task_class (dct )
65
91
assert get_fields (AWorkflow ) == get_fields (Reloaded )
66
92
@@ -73,15 +99,57 @@ def AWorkflow(in_file: TextFile, a_param: int) -> TextFile:
73
99
74
100
def test_serialize_task_class_with_value_serializer ():
75
101
76
- def frozen_set_to_list_serializer (
77
- mock_class : ty .Any , atr : attrs .Attribute , value : ty .Any
102
+ @python .define
103
+ def Identity (a : int ) -> int :
104
+ """
105
+ Parameters
106
+ ----------
107
+ a: int
108
+ the arg
109
+
110
+ Returns
111
+ -------
112
+ out : int
113
+ a returned as is
114
+ """
115
+ return a
116
+
117
+ def type_to_str_serializer (
118
+ klass : ty .Any , atr : attrs .Attribute , value : ty .Any
78
119
) -> ty .Any :
79
- # This is just a dummy serializer
80
- if isinstance (value , frozenset ):
81
- return list (
82
- frozen_set_to_list_serializer (mock_class , atr , v ) for v in value
83
- )
120
+ if isinstance (value , type ):
121
+ return value .__module__ + "." + value .__name__
84
122
return value
85
123
86
- dct = serialize_task_class (Add , value_serializer = frozen_set_to_list_serializer )
87
- assert dct ["xor" ] == [["b" , "c" ]] or dct ["xor" ] == [["c" , "b" ]]
124
+ dct = serialize_task_class (Identity , value_serializer = type_to_str_serializer )
125
+ assert isinstance (dct , dict )
126
+ check_dict_fully_serialized (dct )
127
+ assert dct ["inputs" ] == {"a" : {"type" : "builtins.int" , "help" : "the arg" }}
128
+
129
+
130
+ def test_serialize_task_class_with_filter ():
131
+
132
+ @python .define
133
+ def Identity (a : int ) -> int :
134
+ """
135
+ Parameters
136
+ ----------
137
+ a: int
138
+ the arg
139
+
140
+ Returns
141
+ -------
142
+ out : int
143
+ a returned as is
144
+ """
145
+ return a
146
+
147
+ def no_helps_filter (atr : attrs .Attribute , value : ty .Any ) -> bool :
148
+ if atr .name == "help" :
149
+ return False
150
+ return filter_out_defaults (atr , value )
151
+
152
+ dct = serialize_task_class (Identity , filter = no_helps_filter )
153
+ assert isinstance (dct , dict )
154
+ check_dict_fully_serialized (dct )
155
+ assert dct ["inputs" ] == {"a" : {"type" : int }}
0 commit comments