Skip to content

Commit 4beb6a0

Browse files
committed
automated _task_type so that it always matches the name of the sub-package of pydra.compose the Task is defined in
1 parent 91ae089 commit 4beb6a0

File tree

6 files changed

+16
-12
lines changed

6 files changed

+16
-12
lines changed

pydra/compose/base/task.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,18 @@ class Task(ty.Generic[OutputsType]):
150150
"""Base class for all tasks"""
151151

152152
# Task type to be overridden in derived classes
153-
_task_type = ""
153+
@classmethod
154+
def _task_type(cls) -> str:
155+
mod_parts = cls.__module__.split(".")
156+
assert len(mod_parts) == 3
157+
assert mod_parts[:2] == ["pydra", "compose"]
158+
return mod_parts[2]
159+
154160
# The attribute containing the function/executable used to run the task
155161
_executor_name = None
156162

157163
# Class attributes
164+
TASK_CLASS_ATTRS = ("xor",)
158165
_xor: frozenset[frozenset[str | None]] = (
159166
frozenset()
160167
) # overwritten in derived classes
@@ -461,7 +468,7 @@ def _hash(self):
461468

462469
@property
463470
def _checksum(self):
464-
return f"{self._task_type}-{self._hash}"
471+
return f"{self._task_type()}-{self._hash}"
465472

466473
def _hash_changes(self):
467474
"""Detects any changes in the hashed values between the current inputs and the
@@ -610,7 +617,7 @@ def _check_resolved(self):
610617

611618
@register_serializer
612619
def bytes_repr_task(obj: Task, cache: Cache) -> ty.Iterator[bytes]:
613-
yield f"task[{obj._task_type}]:(".encode()
620+
yield f"task[{obj._task_type()}]:(".encode()
614621
for field in task_fields(obj):
615622
yield f"{field.name}=".encode()
616623
yield hash_single(getattr(obj, field.name), cache)

pydra/compose/python.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def _from_task(cls, job: "Job[PythonTask]") -> ty.Self:
231231
@attrs.define(kw_only=True, auto_attribs=False, eq=False, repr=False)
232232
class PythonTask(base.Task[PythonOutputsType]):
233233

234-
_task_type = "python"
235234
_executor_name = "function"
236235

237236
def _run(self, job: "Job[PythonTask]", rerun: bool = True) -> None:

pydra/compose/shell/task.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ def append_args_converter(value: ty.Any) -> list[str]:
232232
@attrs.define(kw_only=True, auto_attribs=False, eq=False, repr=False)
233233
class ShellTask(base.Task[ShellOutputsType]):
234234

235-
_task_type = "shell"
236235
_executor_name = "executable"
237236

238237
BASE_NAMES = ["append_args"]

pydra/compose/workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,6 @@ def _from_task(cls, job: "Job[WorkflowTask]") -> ty.Self:
352352
@attrs.define(kw_only=True, auto_attribs=False, eq=False, repr=False)
353353
class WorkflowTask(base.Task[WorkflowOutputsType]):
354354

355-
_task_type = "workflow"
356355
_executor_name = "constructor"
357356

358357
RESERVED_FIELD_NAMES = base.Task.RESERVED_FIELD_NAMES + ("construct",)

pydra/engine/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,11 @@ def add(
251251
if (
252252
environment
253253
and not isinstance(environment, native.Environment)
254-
and task._task_type != "shell"
254+
and task._task_type() != "shell"
255255
):
256256
raise ValueError(
257257
"Environments can only be used with 'shell' tasks not "
258-
f"{task._task_type!r} tasks ({task!r})"
258+
f"{task._task_type()!r} tasks ({task!r})"
259259
)
260260
node = Node[OutputsType](
261261
name=name,

pydra/utils/general.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -609,13 +609,13 @@ def task_def_as_dict(task_def: "type[Task]") -> ty.Dict[str, ty.Any]:
609609
attrs.asdict(o, filter=_filter_defaults) for o in task_fields(task_def.Outputs)
610610
]
611611
dct = {
612-
"type": task_def._task_type,
612+
"type": task_def._task_type(),
613613
task_def._executor_name: executor,
614614
"name": task_def.__name__,
615615
"inputs": {d.pop("name"): d for d in input_dicts},
616616
"outputs": {d.pop("name"): d for d in output_dicts},
617-
"xor": task_def._xor,
618617
}
618+
dct.update({a: getattr(task_def, "_" + a) for a in task_def.TASK_CLASS_ATTRS})
619619

620620
return dct
621621

@@ -635,8 +635,8 @@ def task_def_from_dict(task_def_dict: dict[str, ty.Any]) -> type["Task"]:
635635
"""
636636
dct = copy(task_def_dict)
637637
task_type = dct.pop("type")
638-
compose_module = importlib.import_module(f"pydra.compose.{task_type}")
639-
return compose_module.define(dct.pop(compose_module.Task._executor_name), **dct)
638+
mod = importlib.import_module(f"pydra.compose.{task_type}")
639+
return mod.define(dct.pop(mod.Task._executor_name), **dct)
640640

641641

642642
def _filter_defaults(atr: attrs.Attribute, value: ty.Any) -> bool:

0 commit comments

Comments
 (0)