Skip to content

Commit d4ce1d3

Browse files
committed
fixed up toy environments example, all docs should build now
1 parent b0b45b9 commit d4ce1d3

File tree

5 files changed

+91
-38
lines changed

5 files changed

+91
-38
lines changed

new-docs/source/tutorial/6-workflow.ipynb

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -435,13 +435,36 @@
435435
"outputs": [],
436436
"source": [
437437
"import tempfile\n",
438+
"from pathlib import Path\n",
438439
"import numpy as np\n",
439440
"from fileformats.medimage import Nifti1\n",
440441
"import fileformats.medimage_mrtrix3 as mrtrix3\n",
441442
"from pydra.engine.environments import Docker\n",
442443
"from pydra.design import workflow, python\n",
443444
"from pydra.tasks.mrtrix3.v3_0 import MrConvert, MrThreshold\n",
444445
"\n",
446+
"MRTRIX2NUMPY_DTYPES = {\n",
447+
" \"Int8\": np.dtype(\"i1\"),\n",
448+
" \"UInt8\": np.dtype(\"u1\"),\n",
449+
" \"Int16LE\": np.dtype(\"<i2\"),\n",
450+
" \"Int16BE\": np.dtype(\">i2\"),\n",
451+
" \"UInt16LE\": np.dtype(\"<u2\"),\n",
452+
" \"UInt16BE\": np.dtype(\">u2\"),\n",
453+
" \"Int32LE\": np.dtype(\"<i4\"),\n",
454+
" \"Int32BE\": np.dtype(\">i4\"),\n",
455+
" \"UInt32LE\": np.dtype(\"<u4\"),\n",
456+
" \"UInt32BE\": np.dtype(\">u4\"),\n",
457+
" \"Float32LE\": np.dtype(\"<f4\"),\n",
458+
" \"Float32BE\": np.dtype(\">f4\"),\n",
459+
" \"Float64LE\": np.dtype(\"<f8\"),\n",
460+
" \"Float64BE\": np.dtype(\">f8\"),\n",
461+
" \"CFloat32LE\": np.dtype(\"<c8\"),\n",
462+
" \"CFloat32BE\": np.dtype(\">c8\"),\n",
463+
" \"CFloat64LE\": np.dtype(\"<c16\"),\n",
464+
" \"CFloat64BE\": np.dtype(\">c16\"),\n",
465+
"}\n",
466+
"\n",
467+
"\n",
445468
"@workflow.define(outputs=[\"out_image\"])\n",
446469
"def ToyMedianThreshold(in_image: Nifti1) -> mrtrix3.ImageFormat:\n",
447470
" \"\"\"A toy example workflow that\n",
@@ -457,23 +480,17 @@
457480
" )\n",
458481
"\n",
459482
" @python.define\n",
460-
" def SelectDataFile(in_file: mrtrix3.ImageHeader) -> mrtrix3.ImageDataFile:\n",
461-
" return in_file.data_file\n",
462-
"\n",
463-
" select_data = workflow.add(SelectDataFile(in_file=input_conversion.out_file))\n",
464-
"\n",
465-
" @python.define\n",
466-
" def Median(data_file: mrtrix3.ImageDataFile) -> float:\n",
467-
" data = np.load(data_file)\n",
483+
" def Median(mih: mrtrix3.ImageHeader) -> float:\n",
484+
" \"\"\"A bespoke function that reads the separate data file in the MRTrix3 image\n",
485+
" header format (i.e. .mih) and calculates the median value.\"\"\"\n",
486+
" dtype = MRTRIX2NUMPY_DTYPES[mih.metadata[\"datatype\"].strip()]\n",
487+
" data = np.frombuffer(Path.read_bytes(mih.data_file), dtype=dtype)\n",
468488
" return np.median(data)\n",
469489
"\n",
470-
" median = workflow.add(Median(data_file=select_data.out))\n",
490+
" median = workflow.add(Median(mih=input_conversion.out_file))\n",
471491
" threshold = workflow.add(\n",
472-
" MrThreshold(\n",
473-
" in_file=in_image,\n",
474-
" abs=median.out\n",
475-
" ), \n",
476-
" environment=Docker(\"mrtrix3/mrtrix3\", tag=\"\")\n",
492+
" MrThreshold(in_file=in_image, out_file=\"binary.mif\", abs=median.out),\n",
493+
" environment=Docker(\"mrtrix3/mrtrix3\", tag=\"latest\"),\n",
477494
" )\n",
478495
"\n",
479496
" output_conversion = workflow.add(\n",

new-docs/source/tutorial/tst.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,33 @@
11
import tempfile
2+
from pathlib import Path
23
import numpy as np
34
from fileformats.medimage import Nifti1
45
import fileformats.medimage_mrtrix3 as mrtrix3
56
from pydra.engine.environments import Docker
67
from pydra.design import workflow, python
78
from pydra.tasks.mrtrix3.v3_0 import MrConvert, MrThreshold
89

10+
MRTRIX2NUMPY_DTYPES = {
11+
"Int8": np.dtype("i1"),
12+
"UInt8": np.dtype("u1"),
13+
"Int16LE": np.dtype("<i2"),
14+
"Int16BE": np.dtype(">i2"),
15+
"UInt16LE": np.dtype("<u2"),
16+
"UInt16BE": np.dtype(">u2"),
17+
"Int32LE": np.dtype("<i4"),
18+
"Int32BE": np.dtype(">i4"),
19+
"UInt32LE": np.dtype("<u4"),
20+
"UInt32BE": np.dtype(">u4"),
21+
"Float32LE": np.dtype("<f4"),
22+
"Float32BE": np.dtype(">f4"),
23+
"Float64LE": np.dtype("<f8"),
24+
"Float64BE": np.dtype(">f8"),
25+
"CFloat32LE": np.dtype("<c8"),
26+
"CFloat32BE": np.dtype(">c8"),
27+
"CFloat64LE": np.dtype("<c16"),
28+
"CFloat64BE": np.dtype(">c16"),
29+
}
30+
931

1032
@workflow.define(outputs=["out_image"])
1133
def ToyMedianThreshold(in_image: Nifti1) -> mrtrix3.ImageFormat:
@@ -22,20 +44,17 @@ def ToyMedianThreshold(in_image: Nifti1) -> mrtrix3.ImageFormat:
2244
)
2345

2446
@python.define
25-
def SelectDataFile(in_file: mrtrix3.ImageHeader) -> mrtrix3.ImageDataFile:
26-
return in_file.data_file
27-
28-
select_data = workflow.add(SelectDataFile(in_file=input_conversion.out_file))
29-
30-
@python.define
31-
def Median(data_file: mrtrix3.ImageDataFile) -> float:
32-
data = np.load(data_file)
47+
def Median(mih: mrtrix3.ImageHeader) -> float:
48+
"""A bespoke function that reads the separate data file in the MRTrix3 image
49+
header format (i.e. .mih) and calculates the median value."""
50+
dtype = MRTRIX2NUMPY_DTYPES[mih.metadata["datatype"].strip()]
51+
data = np.frombuffer(Path.read_bytes(mih.data_file), dtype=dtype)
3352
return np.median(data)
3453

35-
median = workflow.add(Median(data_file=select_data.out))
54+
median = workflow.add(Median(mih=input_conversion.out_file))
3655
threshold = workflow.add(
37-
MrThreshold(in_file=in_image, abs=median.out),
38-
environment=Docker("mrtrix3/mrtrix3", tag=""),
56+
MrThreshold(in_file=in_image, out_file="binary.mif", abs=median.out),
57+
environment=Docker("mrtrix3/mrtrix3", tag="latest"),
3958
)
4059

4160
output_conversion = workflow.add(

pydra/engine/environments.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class Docker(Container):
9494
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
9595
docker_img = f"{self.image}:{self.tag}"
9696
# mounting all input locations
97-
mounts, inputs_mod_root = task.definition._get_bindings(root=self.root)
97+
mounts, input_updates = task.definition._get_bindings(root=self.root)
9898

9999
docker_args = [
100100
"docker",
@@ -112,7 +112,11 @@ def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
112112
keys = ["return_code", "stdout", "stderr"]
113113

114114
values = execute(
115-
docker_args + [docker_img] + task.definition._command_args(root=self.root),
115+
docker_args
116+
+ [docker_img]
117+
+ task.definition._command_args(
118+
root=self.root, input_updates=input_updates
119+
),
116120
)
117121
output = dict(zip(keys, values))
118122
if output["return_code"]:
@@ -129,7 +133,7 @@ class Singularity(Container):
129133
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
130134
singularity_img = f"{self.image}:{self.tag}"
131135
# mounting all input locations
132-
mounts, inputs_mod_root = task.definition._get_bindings(root=self.root)
136+
mounts, input_updates = task.definition._get_bindings(root=self.root)
133137

134138
# todo adding xargsy etc
135139
singularity_args = [
@@ -150,7 +154,9 @@ def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
150154
values = execute(
151155
singularity_args
152156
+ [singularity_img]
153-
+ task.definition._command_args(root=self.root),
157+
+ task.definition._command_args(
158+
root=self.root, input_updates=input_updates
159+
),
154160
)
155161
output = dict(zip(keys, values))
156162
if output["return_code"]:

pydra/engine/specs.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import attrs
1717
from attrs.converters import default_if_none
1818
import cloudpickle as cp
19-
from fileformats.core import FileSet
19+
from fileformats.generic import FileSet
2020
from pydra.utils.messenger import AuditFlag, Messenger
2121
from pydra.utils.typing import TypeParser, is_optional, non_optional_type
2222
from .helpers import (
@@ -1224,7 +1224,7 @@ def _get_bindings(
12241224
Mapping from paths in the host environment to the target environment
12251225
"""
12261226
bindings: dict[str, tuple[str, str]] = {}
1227-
inputs_mod_root: dict[str, tuple[Path, ...]] = {}
1227+
input_updates: dict[str, tuple[Path, ...]] = {}
12281228
if root is None:
12291229
return bindings
12301230
fld: Arg
@@ -1233,10 +1233,10 @@ def _get_bindings(
12331233
fileset: FileSet | None = self[fld.name]
12341234
if fileset is None:
12351235
continue
1236-
if not isinstance(fileset, FileSet):
1236+
if not isinstance(fileset, (os.PathLike, FileSet)):
12371237
raise NotImplementedError(
1238-
"Generating environment bindings for nested FileSets are not "
1239-
"yet supported"
1238+
"Generating environment bindings for nested FileSets is not "
1239+
"supported yet"
12401240
)
12411241
copy = fld.copy_mode == FileSet.CopyMode.copy
12421242

@@ -1245,11 +1245,17 @@ def _get_bindings(
12451245
# Default to mounting paths as read-only, but respect existing modes
12461246
bindings[host_path] = (env_path, "rw" if copy else "ro")
12471247

1248-
# Provide in-container paths without type-checking
1249-
inputs_mod_root[fld.name] = tuple(
1250-
env_path / rel for rel in fileset.relative_fspaths
1248+
# Provide updated in-container paths to the command to be run. If a
1249+
# fs-object, which resolves to a single path, just pass in the name of
1250+
# that path relative to the location in the mount point in the container.
1251+
# If it is a more complex file-set with multiple paths, then it is converted
1252+
# into a tuple of paths relative to the base of the fileset.
1253+
input_updates[fld.name] = (
1254+
env_path / fileset.name
1255+
if isinstance(fileset, os.PathLike)
1256+
else tuple(env_path / rel for rel in fileset.relative_fspaths)
12511257
)
1252-
return bindings, inputs_mod_root
1258+
return bindings, input_updates
12531259

12541260
def _generated_output_names(self, stdout: str, stderr: str):
12551261
"""Returns a list of all outputs that will be generated by the task.

pydra/utils/typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,11 @@ def __call__(self, obj: ty.Any) -> T:
261261
try:
262262
coerced = self.coerce(obj)
263263
except TypeError as e:
264+
if obj is None:
265+
raise TypeError(
266+
f"Mandatory field{self.label_str} of type {self.tp} was not "
267+
"provided a value (i.e. a value that wasn't None) "
268+
) from None
264269
raise TypeError(
265270
f"Incorrect type for field{self.label_str}: {obj!r} is not of type "
266271
f"{self.tp} (and cannot be coerced to it)"

0 commit comments

Comments
 (0)