Skip to content

Commit f0fbb66

Browse files
committed
handle the formatting of optional types in command_pos_args
1 parent d056d53 commit f0fbb66

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

pydra/engine/specs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import cloudpickle as cp
1818
from fileformats.core import FileSet
1919
from pydra.utils.messenger import AuditFlag, Messenger
20-
from pydra.utils.typing import TypeParser
20+
from pydra.utils.typing import TypeParser, is_optional, non_optional_type
2121
from .helpers import (
2222
attrs_fields,
2323
attrs_values,
@@ -1100,6 +1100,7 @@ def _command_pos_args(
11001100
cmd_add = []
11011101
# formatter that creates a custom command argument
11021102
# it can take the value of the field, all inputs, or the value of other fields.
1103+
tp = non_optional_type(field.type) if is_optional(field.type) else field.type
11031104
if field.formatter:
11041105
call_args = inspect.getfullargspec(field.formatter)
11051106
call_args_val = {}
@@ -1121,20 +1122,20 @@ def _command_pos_args(
11211122
cmd_el_str = cmd_el_str.strip().replace(" ", " ")
11221123
if cmd_el_str != "":
11231124
cmd_add += split_cmd(cmd_el_str)
1124-
elif field.type is bool and "{" not in field.argstr:
1125+
elif tp is bool and "{" not in field.argstr:
11251126
# if value is simply True the original argstr is used,
11261127
# if False, nothing is added to the command.
11271128
if value is True:
11281129
cmd_add.append(field.argstr)
1129-
elif ty.get_origin(field.type) is MultiInputObj:
1130+
elif ty.get_origin(tp) is MultiInputObj:
11301131
# if the field is MultiInputObj, it is used to create a list of arguments
11311132
for val in value or []:
11321133
cmd_add += self._format_arg(field, val)
11331134
else:
11341135
cmd_add += self._format_arg(field, value)
11351136
return field.position, cmd_add
11361137

1137-
def _format_arg(self, field: shell.arg, value: ty.Any) -> list[str]:
1138+
def _format_arg(self, field: shell.arg, value: ty.Any, tp: type) -> list[str]:
11381139
"""Returning arguments used to specify the command args for a single inputs"""
11391140
if (
11401141
field.argstr.endswith("...")

pydra/utils/typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,15 @@ def is_optional(type_: type) -> bool:
10421042
return False
10431043

10441044

1045+
def non_optional_type(type_: type) -> type:
1046+
if is_optional(type_):
1047+
non_optional = [a for a in ty.get_args(type_) if a is not type(None)]
1048+
if len(non_optional) == 1:
1049+
return non_optional[0]
1050+
return ty.Union[tuple(non_optional)]
1051+
return type_
1052+
1053+
10451054
def is_fileset_or_union(type_: type) -> bool:
10461055
"""Check if the type is a FileSet or a Union containing a FileSet"""
10471056
if is_union(type_):

0 commit comments

Comments
 (0)