Skip to content

Commit 78690fc

Browse files
committed
fixed up handling of union types
1 parent 5992a20 commit 78690fc

File tree

5 files changed

+23
-11
lines changed

5 files changed

+23
-11
lines changed

nipype2pydra/interface/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .base import BaseInterfaceConverter
2-
from .function import FunctionInterfaceConverter
3-
from .shell_command import ShellCommandInterfaceConverter
2+
from .python import PythonInterfaceConverter
3+
from .shell import ShellInterfaceConverter
44
from .base import (
55
InputsConverter,
66
OutputsConverter,
@@ -11,8 +11,8 @@
1111

1212
__all__ = [
1313
"BaseInterfaceConverter",
14-
"FunctionInterfaceConverter",
15-
"ShellCommandInterfaceConverter",
14+
"PythonInterfaceConverter",
15+
"ShellInterfaceConverter",
1616
"InputsConverter",
1717
"OutputsConverter",
1818
"TestGenerator",

nipype2pydra/interface/loaders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ def get_converter(nipype_module: str, nipype_name: str, **kwargs):
66
nipype_interface = getattr(import_module(nipype_module), nipype_name)
77

88
if hasattr(nipype_interface, "_cmd"):
9-
from .shell_command import ShellCommandInterfaceConverter as Converter
9+
from .shell import ShellInterfaceConverter as Converter
1010
else:
11-
from .function import FunctionInterfaceConverter as Converter
11+
from .python import PythonInterfaceConverter as Converter
1212

1313
return Converter(nipype_module=nipype_module, nipype_name=nipype_name, **kwargs)

nipype2pydra/interface/function.py renamed to nipype2pydra/interface/python.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@attrs.define(slots=False)
16-
class FunctionInterfaceConverter(BaseInterfaceConverter):
16+
class PythonInterfaceConverter(BaseInterfaceConverter):
1717

1818
converter_type = "function"
1919

@@ -153,7 +153,11 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> ty.Tuple[
153153
spec_str += f" {name}: {type_}\n"
154154

155155
spec_str += " @staticmethod\n"
156-
spec_str += " def function(" + ", ".join(f"{n}: {t}" for n, t, _ in input_fields) + ")"
156+
spec_str += (
157+
" def function("
158+
+ ", ".join(f"{n}: {t}" for n, t, _ in input_fields)
159+
+ ")"
160+
)
157161
output_types = [o[1] for o in output_fields]
158162
if any(t is not ty.Any for t in output_types):
159163
spec_str += "-> "

nipype2pydra/interface/shell_command.py renamed to nipype2pydra/interface/shell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
@attrs.define(slots=False)
29-
class ShellCommandInterfaceConverter(BaseInterfaceConverter):
29+
class ShellInterfaceConverter(BaseInterfaceConverter):
3030

3131
converter_type = "shell_command"
3232
_format_argstrs: ty.Dict[str, str] = attrs.field(factory=dict)

nipype2pydra/utils/misc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from fileformats.core import FileSet, from_mime
1212
from fileformats.core.mixin import WithClassifiers
13+
from pydra.utils.typing import is_union, is_optional
1314
from ..exceptions import (
1415
UnmatchedParensException,
1516
UnmatchedQuoteException,
@@ -590,8 +591,15 @@ def type_to_str(type_: type, mandatory: bool = False) -> str:
590591
type_str = type_.__name__
591592
else:
592593
type_str = str(type_)
593-
if origin := ty.get_origin(type):
594-
args = [type_to_str(arg) for arg in ty.get_args(type_)]
594+
if is_union(type_):
595+
args = [t if t is not type(None) else None for t in ty.get_args(type_)]
596+
if not mandatory and not is_optional(type_):
597+
args.append(None)
598+
return " | ".join(
599+
type_to_str(a, mandatory=True) if a is not None else "None" for a in args
600+
)
601+
if origin := ty.get_origin(type_):
602+
args = [type_to_str(arg, mandatory=True) for arg in ty.get_args(type_)]
595603
type_str = f"{origin.__name__}[{', '.join(args)}]"
596604
module = origin.__module__
597605
else:

0 commit comments

Comments
 (0)