Skip to content

Commit b080cca

Browse files
committed
added task_def_as_dict and task_def_from_dict utility methods for serializing task classes to and from dictionaries
1 parent da37abb commit b080cca

File tree

7 files changed

+111
-6
lines changed

7 files changed

+111
-6
lines changed

pydra/compose/base/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ def extract_function_inputs_and_outputs(
217217
inpt.default = default
218218
elif inspect.isclass(inpt) or ty.get_origin(inpt):
219219
inputs[inpt_name] = arg_type(type=inpt, default=default)
220+
elif isinstance(inpt, dict):
221+
inputs[inpt_name] = arg_type(**inpt)
220222
else:
221223
raise ValueError(
222224
f"Unrecognised input type ({inpt}) for input {inpt_name} with default "

pydra/compose/base/task.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def __repr__(self) -> str:
149149
class Task(ty.Generic[OutputsType]):
150150
"""Base class for all tasks"""
151151

152+
# Task type to be overridden in derived classes
153+
_task_type = ""
154+
# The attribute containing the function/executable used to run the task
155+
_executor_name = None
156+
152157
# Class attributes
153158
_xor: frozenset[frozenset[str | None]] = (
154159
frozenset()

pydra/compose/python.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def define(
102102
bases: ty.Sequence[type] = (),
103103
outputs_bases: ty.Sequence[type] = (),
104104
auto_attribs: bool = True,
105+
name: str | None = None,
105106
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (),
106107
) -> "Task":
107108
"""
@@ -117,6 +118,8 @@ def define(
117118
The outputs of the function or class.
118119
auto_attribs : bool
119120
Whether to use auto_attribs mode when creating the class.
121+
name: str | None
122+
The name of the returned class
120123
xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional
121124
Names of args that are exclusive mutually exclusive, which must include
122125
the name of the current field. If this list includes None, then none of the
@@ -132,7 +135,7 @@ def make(wrapped: ty.Callable | type) -> Task:
132135
if inspect.isclass(wrapped):
133136
klass = wrapped
134137
function = klass.function
135-
name = klass.__name__
138+
class_name = klass.__name__
136139
check_explicit_fields_are_none(klass, inputs, outputs)
137140
parsed_inputs, parsed_outputs = extract_fields_from_class(
138141
Task,
@@ -154,7 +157,8 @@ def make(wrapped: ty.Callable | type) -> Task:
154157
inferred_inputs, inferred_outputs = extract_function_inputs_and_outputs(
155158
function, arg, inputs, outputs
156159
)
157-
name = function.__name__
160+
161+
class_name = function.__name__ if name is None else name
158162

159163
parsed_inputs, parsed_outputs = ensure_field_objects(
160164
arg_type=arg,
@@ -179,7 +183,7 @@ def make(wrapped: ty.Callable | type) -> Task:
179183
Outputs,
180184
parsed_inputs,
181185
parsed_outputs,
182-
name=name,
186+
name=class_name,
183187
klass=klass,
184188
bases=bases,
185189
outputs_bases=outputs_bases,
@@ -228,6 +232,7 @@ def _from_task(cls, job: "Job[PythonTask]") -> ty.Self:
228232
class PythonTask(base.Task[PythonOutputsType]):
229233

230234
_task_type = "python"
235+
_executor_name = "function"
231236

232237
def _run(self, job: "Job[PythonTask]", rerun: bool = True) -> None:
233238
# Prepare the inputs to the function

pydra/compose/shell/task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def append_args_converter(value: ty.Any) -> list[str]:
233233
class ShellTask(base.Task[ShellOutputsType]):
234234

235235
_task_type = "shell"
236+
_executor_name = "executable"
236237

237238
BASE_NAMES = ["append_args"]
238239

pydra/compose/workflow.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def define(
111111
outputs_bases: ty.Sequence[type] = (),
112112
lazy: list[str] | None = None,
113113
auto_attribs: bool = True,
114+
name: str | None = None,
114115
xor: ty.Sequence[str | None] | ty.Sequence[ty.Sequence[str | None]] = (),
115116
) -> "Task":
116117
"""
@@ -127,6 +128,8 @@ def define(
127128
The outputs of the function or class.
128129
auto_attribs : bool
129130
Whether to use auto_attribs mode when creating the class.
131+
name: str | None
132+
The name of the returned class
130133
xor: Sequence[str | None] | Sequence[Sequence[str | None]], optional
131134
Names of args that are exclusive mutually exclusive, which must include
132135
the name of the current field. If this list includes None, then none of the
@@ -145,7 +148,7 @@ def make(wrapped: ty.Callable | type) -> Task:
145148
if inspect.isclass(wrapped):
146149
klass = wrapped
147150
constructor = klass.constructor
148-
name = klass.__name__
151+
class_name = klass.__name__
149152
check_explicit_fields_are_none(klass, inputs, outputs)
150153
parsed_inputs, parsed_outputs = extract_fields_from_class(
151154
Task,
@@ -167,7 +170,8 @@ def make(wrapped: ty.Callable | type) -> Task:
167170
inferred_inputs, inferred_outputs = extract_function_inputs_and_outputs(
168171
constructor, arg, inputs, outputs
169172
)
170-
name = constructor.__name__
173+
174+
class_name = constructor.__name__ if name is None else name
171175

172176
parsed_inputs, parsed_outputs = ensure_field_objects(
173177
arg_type=arg,
@@ -195,7 +199,7 @@ def make(wrapped: ty.Callable | type) -> Task:
195199
Outputs,
196200
parsed_inputs,
197201
parsed_outputs,
198-
name=name,
202+
name=class_name,
199203
klass=klass,
200204
bases=bases,
201205
outputs_bases=outputs_bases,
@@ -349,6 +353,7 @@ def _from_task(cls, job: "Job[WorkflowTask]") -> ty.Self:
349353
class WorkflowTask(base.Task[WorkflowOutputsType]):
350354

351355
_task_type = "workflow"
356+
_executor_name = "constructor"
352357

353358
RESERVED_FIELD_NAMES = base.Task.RESERVED_FIELD_NAMES + ("construct",)
354359

pydra/utils/general.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import sys
66
import typing as ty
7+
from copy import copy
78
import re
89
import attrs
910
import ast
@@ -584,3 +585,61 @@ def get_plugin_classes(namespace: types.ModuleType, class_name: str) -> dict[str
584585
for pkg in sub_packages
585586
if hasattr(pkg, class_name)
586587
}
588+
589+
590+
def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
591+
"""Converts a Pydra task class into a dictionary representation that can be serialized
592+
and saved to a file, then read and passed to an appropriate `pydra.compose.*.define`
593+
method to recreate the task.
594+
595+
Parameters
596+
----------
597+
task_def : type[pydra.compose.base.Task]
598+
The Pydra task class to convert.
599+
600+
Returns
601+
-------
602+
dict[str, ty.Any]
603+
A dictionary representation of the Pydra task.
604+
"""
605+
input_fields = task_fields(task_def)
606+
executor = input_fields.pop(task_def._executor_name).default
607+
input_dicts = [attrs.asdict(i, filter=_filter_defaults) for i in input_fields]
608+
output_dicts = [
609+
attrs.asdict(o, filter=_filter_defaults) for o in task_fields(task_def.Outputs)
610+
]
611+
dct = {
612+
"type": task_def._task_type,
613+
task_def._executor_name: executor,
614+
"name": task_def.__name__,
615+
"inputs": {d.pop("name"): d for d in input_dicts},
616+
"outputs": {d.pop("name"): d for d in output_dicts},
617+
"xor": task_def._xor,
618+
}
619+
620+
return dct
621+
622+
623+
def task_def_from_dict(task_def_dict: dict[str, ty.Any]) -> type["Task"]:
624+
"""Unserializes a task definition from a dictionary created by `task_def_as_dict`
625+
626+
Parameters
627+
----------
628+
task_def_dict: dict[str, Any]
629+
the dictionary representation to unserialize
630+
631+
Returns
632+
-------
633+
type[pydra.compose.base.Task]
634+
the unserialized task class
635+
"""
636+
dct = copy(task_def_dict)
637+
task_type = dct.pop("type")
638+
compose_module = importlib.import_module(f"pydra.compose.{task_type}")
639+
return compose_module.define(dct.pop("function"), **dct)
640+
641+
642+
def _filter_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:
643+
if value == atr.default:
644+
return False
645+
return True

pydra/utils/tests/test_general.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from pydra.compose import python # , workflow, shell
2+
from pydra.utils.general import task_def_as_dict, task_def_from_dict, task_fields
3+
4+
5+
def test_python_task_def_as_dict():
6+
7+
@python.define(outputs=["out_int"], xor=["b", "c"])
8+
def Add(a: int, b: int | None = None, c: int | None = None) -> int:
9+
"""
10+
Parameters
11+
----------
12+
a: int
13+
the first arg
14+
b : int, optional
15+
the optional second arg
16+
c : int, optional
17+
the optional third arg
18+
19+
Returns
20+
-------
21+
out_int : int
22+
the sum of a and b
23+
"""
24+
return a + (b if b is not None else c)
25+
26+
dct = task_def_as_dict(Add)
27+
Reloaded = task_def_from_dict(dct)
28+
assert task_fields(Add) == task_fields(Reloaded)

0 commit comments

Comments
 (0)