Skip to content

Commit 98718c8

Browse files
committed
fixed broken _task_type and _from_job
1 parent 63969ff commit 98718c8

File tree

5 files changed

+14
-8
lines changed

5 files changed

+14
-8
lines changed

pydra/compose/base/task.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,15 @@ class Task(ty.Generic[OutputsType]):
152152
# Task type to be overridden in derived classes
153153
@classmethod
154154
def _task_type(cls) -> str:
155-
mod_parts = cls.__module__.split(".")
156-
assert mod_parts[:2] == ["pydra", "compose"]
157-
return mod_parts[2]
155+
for base in cls.__mro__:
156+
parts = base.__module__.split(".")
157+
if parts[:2] == ["pydra", "compose"]:
158+
return parts[2]
159+
raise RuntimeError(
160+
f"Cannot determine task type for {cls.__name__} in module {cls.__module__} "
161+
"because none of its base classes are in the pydra.compose namespace:\n"
162+
+ "\n".join(f"{b.__name__!r} in {b.__module__!r}" for b in cls.__mro__)
163+
)
158164

159165
# The attribute containing the function/executable used to run the task
160166
_executor_name = None

pydra/compose/python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _from_job(cls, job: "Job[PythonTask]") -> ty.Self:
219219
outputs : Outputs
220220
The outputs of the job in dataclass
221221
"""
222-
outputs = super()._from_task(job)
222+
outputs = super()._from_job(job)
223223
for name, val in job.return_values.items():
224224
setattr(outputs, name, val)
225225
return outputs

pydra/compose/shell/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _from_job(cls, job: "Job[Task]") -> ty.Self:
7878
outputs : Outputs
7979
The outputs of the shell process
8080
"""
81-
outputs = super()._from_task(job)
81+
outputs = super()._from_job(job)
8282
fld: field.out
8383
for fld in task_fields(cls):
8484
if fld.name in ["return_code", "stdout", "stderr"]:

pydra/compose/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _from_job(cls, job: "Job[WorkflowTask]") -> ty.Self:
340340
values[name] = val_out
341341

342342
# Set the values in the outputs object
343-
outputs = super()._from_task(job)
343+
outputs = super()._from_job(job)
344344
outputs = attrs.evolve(outputs, **values)
345345
outputs._cache_dir = job.cache_dir
346346
return outputs

pydra/engine/job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def run(self, rerun: bool = False):
331331
try:
332332
self.audit.monitor()
333333
self.task._run(self, rerun)
334-
result.outputs = self.task.Outputs._from_task(self)
334+
result.outputs = self.task.Outputs._from_job(self)
335335
except Exception:
336336
etype, eval, etr = sys.exc_info()
337337
traceback = format_exception(etype, eval, etr)
@@ -385,7 +385,7 @@ async def run_async(self, rerun: bool = False) -> Result:
385385
try:
386386
self.audit.monitor()
387387
await self.task._run_async(self, rerun)
388-
result.outputs = self.task.Outputs._from_task(self)
388+
result.outputs = self.task.Outputs._from_job(self)
389389
except Exception:
390390
etype, eval, etr = sys.exc_info()
391391
traceback = format_exception(etype, eval, etr)

0 commit comments

Comments
 (0)