Skip to content

Commit 16fa35a

Browse files
committed
added show_workflow utils function and use it in the tutorial
1 parent 52a5187 commit 16fa35a

File tree

5 files changed

+124
-30
lines changed

5 files changed

+124
-30
lines changed

docs/source/tutorial/6-workflow.ipynb

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"outputs": [],
3030
"source": [
3131
"from pydra.compose import workflow, python\n",
32+
"from pydra.utils import show_workflow, print_help\n",
3233
"\n",
3334
"\n",
3435
"# Example python tasks\n",
@@ -59,7 +60,10 @@
5960
"def BasicWorkflow(a, b):\n",
6061
" add = workflow.add(Add(a=a, b=b))\n",
6162
" mul = workflow.add(Mul(a=add.out, b=b))\n",
62-
" return mul.out"
63+
" return mul.out\n",
64+
"\n",
65+
"\n",
66+
"show_workflow(BasicWorkflow, figsize=(2, 2.5))"
6367
]
6468
},
6569
{
@@ -109,7 +113,10 @@
109113
" )(in_video=add_watermark.out_video, width=1280, height=720)\n",
110114
" ).out_video\n",
111115
"\n",
112-
" return output_video # test implicit detection of output name"
116+
" return output_video # test implicit detection of output name\n",
117+
"\n",
118+
"\n",
119+
"show_workflow(ShellWorkflow, figsize=(2.5, 3))"
113120
]
114121
},
115122
{
@@ -136,10 +143,13 @@
136143
"def SplitWorkflow(a: list[int], b: list[float]) -> list[float]:\n",
137144
" # Multiply over all combinations of the elements of a and b, then combine the results\n",
138145
" # for each a element into a list over each b element\n",
139-
" mul = workflow.add(Mul().split(x=a, y=b).combine(\"x\"))\n",
146+
" mul = workflow.add(Mul().split(a=a, b=b).combine(\"a\"))\n",
140147
" # Sume the multiplications across all all b elements for each a element\n",
141148
" sum = workflow.add(Sum(x=mul.out))\n",
142-
" return sum.out"
149+
" return sum.out\n",
150+
"\n",
151+
"\n",
152+
"show_workflow(SplitWorkflow, figsize=(2, 2.5))"
143153
]
144154
},
145155
{
@@ -157,10 +167,13 @@
157167
"source": [
158168
"@workflow.define\n",
159169
"def SplitThenCombineWorkflow(a: list[int], b: list[float], c: float) -> list[float]:\n",
160-
" mul = workflow.add(Mul().split(x=a, y=b))\n",
161-
" add = workflow.add(Add(x=mul.out, y=c).combine(\"Mul.x\"))\n",
170+
" mul = workflow.add(Mul().split(a=a, b=b))\n",
171+
" add = workflow.add(Add(a=mul.out, b=c).combine(\"Mul.a\"))\n",
162172
" sum = workflow.add(Sum(x=add.out))\n",
163-
" return sum.out"
173+
" return sum.out\n",
174+
"\n",
175+
"\n",
176+
"show_workflow(SplitThenCombineWorkflow, figsize=(3, 3.5))"
164177
]
165178
},
166179
{
@@ -214,7 +227,10 @@
214227
" )(in_video=handbrake_input, width=1280, height=720)\n",
215228
" ).out_video\n",
216229
"\n",
217-
" return output_video # test implicit detection of output name"
230+
" return output_video # test implicit detection of output name\n",
231+
"\n",
232+
"\n",
233+
"show_workflow(ConditionalWorkflow(watermark_dims=(10, 10)), figsize=(2.5, 3))"
218234
]
219235
},
220236
{
@@ -238,15 +254,18 @@
238254
"\n",
239255
"@workflow.define\n",
240256
"def RecursiveNestedWorkflow(a: float, depth: int) -> float:\n",
241-
" add = workflow.add(Add(x=a, y=1))\n",
257+
" add = workflow.add(Add(a=a, b=1))\n",
242258
" decrement_depth = workflow.add(Subtract(x=depth, y=1))\n",
243259
" if depth > 0:\n",
244260
" out_node = workflow.add(\n",
245261
" RecursiveNestedWorkflow(a=add.out, depth=decrement_depth.out)\n",
246262
" )\n",
247263
" else:\n",
248264
" out_node = add\n",
249-
" return out_node.out"
265+
" return out_node.out\n",
266+
"\n",
267+
"\n",
268+
"print_help(RecursiveNestedWorkflow)"
250269
]
251270
},
252271
{
@@ -327,7 +346,10 @@
327346
" Mp4Handbrake(in_video=add_watermark.out_video, width=1280, height=720),\n",
328347
" ) # The type of the input video is now correct\n",
329348
"\n",
330-
" return handbrake.output_video"
349+
" return handbrake.out_video\n",
350+
"\n",
351+
"\n",
352+
"show_workflow(TypeErrorWorkflow, plot_type=\"detailed\")"
331353
]
332354
},
333355
{
@@ -372,14 +394,17 @@
372394
"\n",
373395
" wf = workflow.this()\n",
374396
"\n",
375-
" add = wf.add(Add(x=a, y=b), name=\"addition\")\n",
376-
" mul = wf.add(python.define(Mul, outputs={\"out\": float})(x=add.z, y=b))\n",
377-
" divide = wf.add(Divide(x=wf[\"addition\"].lzout.z, y=mul.out), name=\"division\")\n",
397+
" add = wf.add(Add(a=a, b=b), name=\"addition\")\n",
398+
" mul = wf.add(Mul(a=add.out, b=b))\n",
399+
" divide = wf.add(Divide(x=wf[\"addition\"].lzout.out, y=mul.out), name=\"division\")\n",
378400
"\n",
379401
" # Alter one of the inputs to a node after it has been initialised\n",
380-
" wf[\"Mul\"].inputs.y *= 2\n",
402+
" wf[\"Mul\"].inputs.b *= 2\n",
381403
"\n",
382-
" return mul.out, divide.divided"
404+
" return mul.out, divide.divided\n",
405+
"\n",
406+
"\n",
407+
"show_workflow(DirectAccesWorkflow(b=1), plot_type=\"detailed\")"
383408
]
384409
},
385410
{
@@ -410,16 +435,19 @@
410435
"\n",
411436
" wf = workflow.this()\n",
412437
"\n",
413-
" add = wf.add(Add(x=a, y=b), name=\"addition\")\n",
414-
" mul = wf.add(python.define(Mul, outputs={\"out\": float})(x=add.z, y=b))\n",
415-
" divide = wf.add(Divide(x=wf[\"addition\"].lzout.z, y=mul.out), name=\"division\")\n",
438+
" add = wf.add(Add(a=a, b=b), name=\"addition\")\n",
439+
" mul = wf.add(Mul(a=add.out, b=b))\n",
440+
" divide = wf.add(Divide(x=wf[\"addition\"].lzout.out, y=mul.out), name=\"division\")\n",
416441
"\n",
417442
" # Alter one of the inputs to a node after it has been initialised\n",
418-
" wf[\"Mul\"].inputs.y *= 2\n",
443+
" wf[\"Mul\"].inputs.b *= 2\n",
419444
"\n",
420445
" # Set the outputs of the workflow directly\n",
421446
" wf.outputs.out1 = mul.out\n",
422-
" wf.outputs.out2 = divide.divided"
447+
" wf.outputs.out2 = divide.divided\n",
448+
"\n",
449+
"\n",
450+
"show_workflow(SetOutputsOfWorkflow(b=3), plot_type=\"detailed\")"
423451
]
424452
},
425453
{
@@ -514,13 +542,13 @@
514542
" return output_conversion.out_file\n",
515543
"\n",
516544
"\n",
517-
"test_dir = tempfile.mkdtemp()\n",
545+
"test_dir = Path(tempfile.mkdtemp())\n",
518546
"\n",
519547
"nifti_file = Nifti1.sample(test_dir, seed=0)\n",
520548
"\n",
521549
"wf = ToyMedianThreshold(in_image=nifti_file)\n",
522550
"\n",
523-
"outputs = wf()\n",
551+
"outputs = wf(cache_dir=test_dir / \"cache\")\n",
524552
"\n",
525553
"print(outputs)"
526554
]

docs/source/tutorial/7-canonical-form.ipynb

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,16 @@
230230
" return mul.out\n",
231231
"\n",
232232
" class Outputs(workflow.Outputs):\n",
233-
" out: float"
233+
" out: float\n",
234+
"\n",
235+
"\n",
236+
"print_help(CanonicalWorkflowTask)"
234237
]
235238
}
236239
],
237240
"metadata": {
238241
"kernelspec": {
239-
"display_name": "wf12",
242+
"display_name": "wf13",
240243
"language": "python",
241244
"name": "python3"
242245
},
@@ -250,7 +253,7 @@
250253
"name": "python",
251254
"nbconvert_exporter": "python",
252255
"pygments_lexer": "ipython3",
253-
"version": "3.12.5"
256+
"version": "3.13.1"
254257
}
255258
},
256259
"nbformat": 4,

pydra/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
task_fields,
33
fields_dict,
44
plot_workflow,
5+
show_workflow,
56
task_help,
67
print_help,
78
)
@@ -11,6 +12,7 @@
1112
"__version__",
1213
"task_fields",
1314
"plot_workflow",
15+
"show_workflow",
1416
"task_help",
1517
"print_help",
1618
"fields_dict",

pydra/utils/general.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import re
88
import attrs
99
import ast
10+
import tempfile
1011
import importlib
1112
import types
1213
import sysconfig
@@ -204,15 +205,20 @@ def plot_workflow(
204205
plot_type: str = "simple",
205206
export: ty.Sequence[str] | None = None,
206207
name: str | None = None,
207-
output_dir: Path | None = None,
208-
lazy: ty.Sequence[str] | ty.Set[str] = (),
209-
):
208+
lazy: ty.Sequence[str] | ty.Set[str] | None = None,
209+
) -> Path | tuple[Path, list[Path]]:
210210
"""creating a graph - dotfile and optionally exporting to other formats"""
211211
from pydra.engine.workflow import Workflow
212212

213+
if inspect.isclass(workflow_task):
214+
workflow_task = workflow_task()
215+
213216
# Create output directory
214217
out_dir.mkdir(parents=True, exist_ok=True)
215218

219+
if lazy is None:
220+
lazy = [n for n, v in attrs_values(workflow_task).items() if v is attrs.NOTHING]
221+
216222
# Construct the workflow object with all of the fields lazy
217223
wf = Workflow.construct(workflow_task, lazy=lazy)
218224

@@ -245,6 +251,61 @@ def plot_workflow(
245251
return dotfile, formatted_dot
246252

247253

254+
def show_workflow(
255+
workflow_task: "workflow.Task",
256+
plot_type: str = "simple",
257+
lazy: ty.Sequence[str] | ty.Set[str] | None = None,
258+
use_lib: str | None = None,
259+
figsize: tuple[int, int] | None = None,
260+
**kwargs,
261+
) -> None:
262+
"""creating a graph and showing it"""
263+
out_dir = Path(tempfile.mkdtemp())
264+
png_graph = plot_workflow(
265+
workflow_task,
266+
out_dir=out_dir,
267+
plot_type=plot_type,
268+
export="png",
269+
lazy=lazy,
270+
)[1][0]
271+
272+
if use_lib in ("matplotlib", None):
273+
try:
274+
import matplotlib.pyplot as plt
275+
from matplotlib.image import imread
276+
except ImportError:
277+
if use_lib == "matplotlib":
278+
raise ImportError(
279+
"Please install either matplotlib to display the workflow image."
280+
)
281+
else:
282+
use_lib = "matplotlib"
283+
284+
# Read the image
285+
img = imread(png_graph)
286+
287+
if figsize is not None:
288+
plt.figure(figsize=figsize)
289+
# Display the image
290+
plt.imshow(img, **kwargs)
291+
plt.axis("off")
292+
plt.show()
293+
294+
if use_lib in ("PIL", None):
295+
try:
296+
from PIL import Image
297+
except ImportError:
298+
msg = " or matplotlib" if use_lib is None else ""
299+
raise ImportError(
300+
f"Please install either Pillow{msg} to display the workflow image."
301+
)
302+
# Open the PNG image
303+
img = Image.open(png_graph)
304+
305+
# Display the image
306+
img.show(**kwargs)
307+
308+
248309
def attrs_fields(task, exclude_names=()) -> list[attrs.Attribute]:
249310
"""Get the fields of a task, excluding some names."""
250311
return [field for field in task.__attrs_attrs__ if field.name not in exclude_names]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ classifiers = [
3737
dynamic = ["version"]
3838

3939
[project.optional-dependencies]
40-
dev = ["black", "pre-commit", "pydra[test]"]
40+
dev = ["black", "pre-commit", "pydra[test]", "matplotlib"]
4141
doc = [
4242
"fileformats-extras >= v0.15.0a6",
4343
"fileformats-medimage >= v0.10.0a2",

0 commit comments

Comments
 (0)