Skip to content

Commit 099b0f4

Browse files
committed
fixing up errors in test_shell
1 parent 84a8d3a commit 099b0f4

File tree

4 files changed

+42
-18
lines changed

4 files changed

+42
-18
lines changed

pydra/design/shell.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
make_task_def,
2323
NO_DEFAULT,
2424
)
25-
from pydra.utils.typing import is_fileset_or_union, MultiInputObj
25+
from pydra.utils.typing import (
26+
is_fileset_or_union,
27+
MultiInputObj,
28+
is_optional,
29+
optional_type,
30+
)
2631

2732
if ty.TYPE_CHECKING:
2833
from pydra.engine.specs import ShellDef
@@ -94,11 +99,15 @@ class arg(Arg):
9499

95100
argstr: str | None = ""
96101
position: int | None = None
97-
sep: str | None = attrs.field(default=None)
102+
sep: str | None = attrs.field()
98103
allowed_values: list | None = None
99104
container_path: bool = False # IS THIS STILL USED??
100105
formatter: ty.Callable | None = None
101106

107+
@sep.default
108+
def _sep_default(self):
109+
return " " if self.type is tuple or ty.get_origin(self.type) is tuple else None
110+
102111
@sep.validator
103112
def _validate_sep(self, _, sep):
104113
if self.type is ty.Any:
@@ -107,7 +116,10 @@ def _validate_sep(self, _, sep):
107116
tp = ty.get_args(self.type)[0]
108117
else:
109118
tp = self.type
119+
if is_optional(tp):
120+
tp = optional_type(tp)
110121
origin = ty.get_origin(tp) or tp
122+
111123
if (
112124
inspect.isclass(origin)
113125
and issubclass(origin, ty.Sequence)

pydra/design/tests/test_shell.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_interface_template_more_complex():
162162

163163
Cp = shell.define(
164164
(
165-
"cp <in_fs_objects:fs-object,...> <out|out_dir:directory> "
165+
"cp <in_fs_objects:fs-object+> <out|out_dir:directory> "
166166
"-R<recursive> "
167167
"--text-arg <text_arg?> "
168168
"--int-arg <int_arg:int?> "
@@ -187,7 +187,9 @@ def test_interface_template_more_complex():
187187
help=shell.EXECUTABLE_HELP_STRING,
188188
),
189189
shell.arg(
190-
name="in_fs_objects", type=MultiInputObj[FsObject], position=1, sep=" "
190+
name="in_fs_objects",
191+
type=MultiInputObj[FsObject],
192+
position=1,
191193
),
192194
output,
193195
shell.arg(name="recursive", argstr="-R", type=bool, default=False, position=3),
@@ -210,6 +212,7 @@ def test_interface_template_more_complex():
210212
argstr="--tuple-arg",
211213
type=tuple[int, str] | None,
212214
default=None,
215+
sep=" ",
213216
position=6,
214217
),
215218
ShellDef.additional_args,
@@ -245,7 +248,7 @@ def test_interface_template_with_overrides_and_optionals():
245248

246249
Cp = shell.define(
247250
(
248-
"cp <in_fs_objects:fs-object,...> <out|out_dir:directory> <out|out_file:file?> "
251+
"cp <in_fs_objects:fs-object+> <out|out_dir:directory> <out|out_file:file?> "
249252
"-R<recursive> "
250253
"--text-arg <text_arg> "
251254
"--int-arg <int_arg:int?> "
@@ -284,7 +287,9 @@ def test_interface_template_with_overrides_and_optionals():
284287
help=shell.EXECUTABLE_HELP_STRING,
285288
),
286289
shell.arg(
287-
name="in_fs_objects", type=MultiInputObj[FsObject], position=1, sep=" "
290+
name="in_fs_objects",
291+
type=MultiInputObj[FsObject],
292+
position=1,
288293
),
289294
shell.arg(
290295
name="recursive",
@@ -306,6 +311,7 @@ def test_interface_template_with_overrides_and_optionals():
306311
name="tuple_arg",
307312
argstr="--tuple-arg",
308313
type=tuple[int, str],
314+
sep=" ",
309315
position=5,
310316
),
311317
] + outargs + [ShellDef.additional_args]
@@ -332,7 +338,7 @@ def test_interface_template_with_defaults():
332338

333339
Cp = shell.define(
334340
(
335-
"cp <in_fs_objects:fs-object,...> <out|out_dir:directory> "
341+
"cp <in_fs_objects:fs-object+> <out|out_dir:directory> "
336342
"-R<recursive=True> "
337343
"--text-arg <text_arg='foo'> "
338344
"--int-arg <int_arg:int=99> "
@@ -357,7 +363,9 @@ def test_interface_template_with_defaults():
357363
help=shell.EXECUTABLE_HELP_STRING,
358364
),
359365
shell.arg(
360-
name="in_fs_objects", type=MultiInputObj[FsObject], position=1, sep=" "
366+
name="in_fs_objects",
367+
type=MultiInputObj[FsObject],
368+
position=1,
361369
),
362370
output,
363371
shell.arg(name="recursive", argstr="-R", type=bool, default=True, position=3),
@@ -400,7 +408,7 @@ def test_interface_template_with_type_overrides():
400408

401409
Cp = shell.define(
402410
(
403-
"cp <in_fs_objects:fs-object,...> <out|out_dir:directory> "
411+
"cp <in_fs_objects:fs-object+> <out|out_dir:directory> "
404412
"-R<recursive> "
405413
"--text-arg <text_arg> "
406414
"--int-arg <int_arg> "
@@ -426,7 +434,9 @@ def test_interface_template_with_type_overrides():
426434
help=shell.EXECUTABLE_HELP_STRING,
427435
),
428436
shell.arg(
429-
name="in_fs_objects", type=MultiInputObj[FsObject], position=1, sep=" "
437+
name="in_fs_objects",
438+
type=MultiInputObj[FsObject],
439+
position=1,
430440
),
431441
output,
432442
shell.arg(name="recursive", argstr="-R", type=bool, default=False, position=3),

pydra/engine/specs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import cloudpickle as cp
1919
from fileformats.generic import FileSet
2020
from pydra.utils.messenger import AuditFlag, Messenger
21-
from pydra.utils.typing import TypeParser, is_optional, non_optional_type
21+
from pydra.utils.typing import TypeParser, is_optional, optional_type
2222
from .helpers import (
2323
attrs_fields,
2424
attrs_values,
@@ -1130,7 +1130,7 @@ def _command_pos_args(
11301130
cmd_add = []
11311131
# formatter that creates a custom command argument
11321132
# it can take the value of the field, all inputs, or the value of other fields.
1133-
tp = non_optional_type(field.type) if is_optional(field.type) else field.type
1133+
tp = optional_type(field.type) if is_optional(field.type) else field.type
11341134
if field.formatter:
11351135
call_args = inspect.getfullargspec(field.formatter)
11361136
call_args_val = {}

pydra/utils/typing.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,22 +1037,24 @@ def label_str(self):
10371037

10381038

10391039
def is_union(type_: type) -> bool:
1040+
"""Checks whether a type is a Union, in either ty.Union[T, U] or T | U form"""
10401041
return ty.get_origin(type_) in UNION_TYPES
10411042

10421043

10431044
def is_optional(type_: type) -> bool:
1044-
"""Check if the type is Optional"""
1045+
"""Check if the type is Optional, i.e. a union containing None"""
10451046
if is_union(type_):
10461047
return any(a is type(None) or is_optional(a) for a in ty.get_args(type_))
10471048
return False
10481049

10491050

1050-
def non_optional_type(type_: type) -> type:
1051+
def optional_type(type_: type) -> type:
1052+
"""Gets the non-None args of an optional type (i.e. a union with a None arg)"""
10511053
if is_optional(type_):
1052-
non_optional = [a for a in ty.get_args(type_) if a is not type(None)]
1053-
if len(non_optional) == 1:
1054-
return non_optional[0]
1055-
return ty.Union[tuple(non_optional)]
1054+
non_none = [a for a in ty.get_args(type_) if a is not type(None)]
1055+
if len(non_none) == 1:
1056+
return non_none[0]
1057+
return ty.Union[tuple(non_none)]
10561058
return type_
10571059

10581060

0 commit comments

Comments
 (0)