Skip to content

Commit b741563

Browse files
committed
debugging test_shelltask
1 parent a1e3268 commit b741563

File tree

8 files changed

+177
-158
lines changed

8 files changed

+177
-158
lines changed

pydra/design/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Self
99
import attrs.validators
1010
from attrs.converters import default_if_none
11-
from fileformats.generic import File
11+
from fileformats.generic import File, FileSet
1212
from pydra.utils.typing import TypeParser, is_optional, is_fileset_or_union, is_type
1313
from pydra.engine.helpers import (
1414
from_list_if_single,
@@ -59,6 +59,8 @@ def convert_default_value(value: ty.Any, self_: "Field") -> ty.Any:
5959
return value
6060
if self_.type is ty.Callable and isinstance(value, ty.Callable):
6161
return value
62+
if isinstance(self_, Out) and TypeParser.contains_type(FileSet, self_.type):
63+
return value
6264
return TypeParser[self_.type](self_.type, label=self_.name)(value)
6365

6466

pydra/design/shell.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44
import typing as ty
55
import re
6+
import glob
67
from collections import defaultdict
78
import inspect
89
from copy import copy
@@ -25,6 +26,7 @@
2526
from pydra.utils.typing import (
2627
is_fileset_or_union,
2728
MultiInputObj,
29+
TypeParser,
2830
is_optional,
2931
optional_type,
3032
)
@@ -439,6 +441,16 @@ def make(
439441
if inpt.position is None:
440442
inpt.position = position_stack.pop(0)
441443

444+
# Convert string default values to callables that glob the files in the cwd
445+
for outpt in parsed_outputs.values():
446+
if (
447+
isinstance(outpt, out)
448+
and isinstance(outpt.default, str)
449+
and TypeParser.contains_type(generic.FileSet, outpt.type)
450+
):
451+
outpt.callable = GlobCallable(outpt.default)
452+
outpt.default = NO_DEFAULT
453+
442454
defn = make_task_def(
443455
ShellDef,
444456
ShellOutputs,
@@ -782,3 +794,16 @@ class _InputPassThrough:
782794

783795
def __call__(self, inputs: ShellDef) -> ty.Any:
784796
return getattr(inputs, self.name)
797+
798+
799+
class GlobCallable:
800+
"""Callable that can be used to glob files"""
801+
802+
def __init__(self, pattern: str):
803+
self.pattern = pattern
804+
805+
def __call__(self) -> generic.FileSet:
806+
matches = glob.glob(self.pattern)
807+
if not matches:
808+
raise FileNotFoundError(f"No files found matching pattern: {self.pattern}")
809+
return matches

pydra/engine/core.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,9 @@ def uid(self):
227227

228228
@property
229229
def output_names(self):
230-
"""Get the names of the outputs from the task's output_spec
231-
(not everything has to be generated, see _generated_output_names).
232-
"""
230+
"""Get the names of the outputs from the task's output_spec"""
233231
return [f.name for f in attr.fields(self.definition.Outputs)]
234232

235-
@property
236-
def _generated_output_names(self):
237-
return self.output_names
238-
239233
@property
240234
def can_resume(self):
241235
"""Whether the task accepts checkpoint-restart."""
@@ -286,7 +280,7 @@ def inputs(self) -> dict[str, ty.Any]:
286280
map_copyfiles[name] = copied_value
287281
self._inputs.update(
288282
template_update(
289-
self.definition, self.output_dir, map_copyfiles=map_copyfiles
283+
self.definition, output_dir=self.output_dir, map_copyfiles=map_copyfiles
290284
)
291285
)
292286
return self._inputs

pydra/engine/environments.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as ty
22
import os
3+
from copy import copy
34
from .helpers import execute
45
from pathlib import Path
56
from fileformats.generic import FileSet
@@ -128,14 +129,14 @@ def get_bindings(
128129
f"No support for generating bindings for {type(fileset)} types "
129130
f"({fileset})"
130131
)
131-
copy = fld.copy_mode == FileSet.CopyMode.copy
132+
copy_file = fld.copy_mode == FileSet.CopyMode.copy
132133

133134
host_path, env_path = fileset.parent, Path(f"{root}{fileset.parent}")
134135

135136
# Default to mounting paths as read-only, but respect existing modes
136137
bindings[host_path] = (
137138
env_path,
138-
"rw" if copy or isinstance(fld, shell.outarg) else "ro",
139+
"rw" if copy_file or isinstance(fld, shell.outarg) else "ro",
139140
)
140141

141142
# Provide updated in-container paths to the command to be run. If a

pydra/engine/helpers_file.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def template_update_single(
159159
values: dict[str, ty.Any] = None,
160160
output_dir: Path | None = None,
161161
spec_type: str = "input",
162-
) -> Path | None:
162+
) -> Path | list[Path | None] | None:
163163
"""Update a single template from the input_spec or output_spec
164164
based on the value from inputs_dict
165165
(checking the types of the fields, that have "output_file_template)"
@@ -196,18 +196,19 @@ def template_update_single(
196196
return None
197197
# inputs_dict[field.name] is True or spec_type is output
198198
value = _template_formatting(field, definition, values)
199-
# changing path so it is in the output_dir
200199
if output_dir and value is not None:
200+
# changing path so it is in the output_dir
201201
# should be converted to str, it is also used for input fields that should be str
202202
if type(value) is list:
203-
return [output_dir / val.name for val in value]
203+
value = [output_dir / val.name for val in value]
204204
else:
205-
return output_dir / value.name
206-
else:
207-
return None
205+
value = output_dir / value.name
206+
return value
208207

209208

210-
def _template_formatting(field, definition, values):
209+
def _template_formatting(
210+
field: "shell.arg", definition: "ShellDef", values: dict[str, ty.Any]
211+
) -> Path | list[Path] | None:
211212
"""Formatting the field template based on the values from inputs.
212213
Taking into account that the field with a template can be a MultiOutputFile
213214
and the field values needed in the template can be a list -
@@ -226,7 +227,7 @@ def _template_formatting(field, definition, values):
226227
227228
Returns
228229
-------
229-
formatted : str or list
230+
formatted : Path or list[Path | None] or None
230231
formatted template
231232
"""
232233
# if a template is a function it has to be run first with the inputs as the only arg
@@ -237,6 +238,8 @@ def _template_formatting(field, definition, values):
237238
# as default, we assume that keep_extension is True
238239
if isinstance(template, (tuple, list)):
239240
formatted = [_single_template_formatting(field, t, values) for t in template]
241+
if any([val is None for val in formatted]):
242+
return None
240243
else:
241244
assert isinstance(template, str)
242245
formatted = _single_template_formatting(field, template, values)

pydra/engine/specs.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -898,12 +898,27 @@ def _from_task(cls, task: "Task[ShellDef]") -> Self:
898898
setattr(outputs, fld.name, None)
899899
else:
900900
raise ValueError(
901-
f"file system path(s) provided to mandatory field {fld.name!r},"
902-
f"{resolved_value}, does not exist, this is likely due to an "
901+
f"file system path(s) provided to mandatory field {fld.name!r}, "
902+
f"'{resolved_value}', does not exist, this is likely due to an "
903903
f"error in the {task.name!r} task"
904904
)
905905
return outputs
906906

907+
# @classmethod
908+
# def _from_defaults(cls) -> Self:
909+
# """Create an output object from the default values of the fields"""
910+
# defaults = {}
911+
# for field in attrs_fields(cls):
912+
# if isinstance(field.default, attrs.Factory):
913+
# defaults[field.name] = field.default.factory()
914+
# elif TypeParser.contains_type(FileSet, field.type):
915+
# # Will be set by the templating code
916+
# defaults[field.name] = attrs.NOTHING
917+
# else:
918+
# defaults[field.name] = field.default
919+
920+
# return cls(**defaults)
921+
907922
@classmethod
908923
def _resolve_default_value(cls, fld: shell.out, output_dir: Path) -> ty.Any:
909924
"""Resolve path and glob expr default values relative to the output dir"""
@@ -991,20 +1006,24 @@ def _resolve_value(
9911006
call_args_val[argnm] = fld
9921007
elif argnm == "output_dir":
9931008
call_args_val[argnm] = task.output_dir
1009+
elif argnm == "executable":
1010+
call_args_val[argnm] = task.definition.executable
9941011
elif argnm == "inputs":
9951012
call_args_val[argnm] = task.inputs
9961013
elif argnm == "stdout":
9971014
call_args_val[argnm] = task.return_values["stdout"]
9981015
elif argnm == "stderr":
9991016
call_args_val[argnm] = task.return_values["stderr"]
1017+
elif argnm == "self":
1018+
pass # If the callable is a class
10001019
else:
10011020
try:
10021021
call_args_val[argnm] = task.inputs[argnm]
10031022
except KeyError as e:
10041023
e.add_note(
1005-
f"arguments of the callable function from {fld.name} "
1024+
f"arguments of the callable function from {fld.name!r} "
10061025
f"has to be in inputs or be field or output_dir, "
1007-
f"but {argnm} is used"
1026+
f"but {argnm!r} is used"
10081027
)
10091028
raise
10101029
return callable_(**call_args_val)
@@ -1040,7 +1059,7 @@ def cmdline(self) -> str:
10401059
the current working directory."""
10411060
# Skip the executable, which can be a multi-part command, e.g. 'docker run'.
10421061
values = attrs_values(self)
1043-
values.update(template_update(self))
1062+
values.update(template_update(self, output_dir=Path.cwd()))
10441063
cmd_args = self._command_args(values=values)
10451064
cmdline = cmd_args[0]
10461065
for arg in cmd_args[1:]:
@@ -1221,22 +1240,6 @@ def _format_arg(self, field: shell.arg, values: dict[str, ty.Any]) -> list[str]:
12211240
cmd_el_str = ""
12221241
return split_cmd(cmd_el_str)
12231242

1224-
def _generated_output_names(self, stdout: str, stderr: str):
1225-
"""Returns a list of all outputs that will be generated by the task.
1226-
Takes into account the task input and the requires list for the output fields.
1227-
TODO: should be in all Output specs?
1228-
"""
1229-
# checking the input (if all mandatory fields are provided, etc.)
1230-
self._check_rules()
1231-
output_names = ["return_code", "stdout", "stderr"]
1232-
for fld in list_fields(self):
1233-
# assuming that field should have either default or metadata, but not both
1234-
if is_set(fld.default):
1235-
output_names.append(fld.name)
1236-
elif is_set(self.Outputs._resolve_output_value(fld, stdout, stderr)):
1237-
output_names.append(fld.name)
1238-
return output_names
1239-
12401243
def _rule_violations(self) -> list[str]:
12411244

12421245
errors = super()._rule_violations()

0 commit comments

Comments
 (0)