Skip to content

Commit 8c696fd

Browse files
committed
debugging mriqc/niworkflows conversions
1 parent abf3551 commit 8c696fd

File tree

11 files changed

+155
-96
lines changed

11 files changed

+155
-96
lines changed

nipype2pydra/helpers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def used_symbols(self) -> UsedSymbols:
133133
always_include=self.package.all_explicit,
134134
translations=self.package.all_import_translations,
135135
)
136-
used.imports.update(i.to_statement() for i in self.imports)
136+
used.import_stmts.update(i.to_statement() for i in self.imports)
137137
return used
138138

139139
@cached_property
@@ -147,12 +147,10 @@ def converted_code(self) -> ty.List[str]:
147147
@cached_property
148148
def nested_interfaces(self):
149149
potential_classes = {
150-
full_address(c[1]): c[0]
151-
for c in self.used_symbols.intra_pkg_classes
152-
if c[0]
150+
full_address(c[1]): c[0] for c in self.used_symbols.imported_classes if c[0]
153151
}
154152
potential_classes.update(
155-
(full_address(c), c.__name__) for c in self.used_symbols.local_classes
153+
(full_address(c), c.__name__) for c in self.used_symbols.classes
156154
)
157155
return {
158156
potential_classes[address]: workflow

nipype2pydra/interface/base.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,9 @@ def pydra_fld_input(self, field, nm):
657657
val = getattr(field, key)
658658
if val is not None:
659659
if key == "argstr" and "%" in val:
660-
val = self.string_formats(argstr=val, name=nm)
660+
val = self.string_formats(
661+
argstr=val, name=nm, type_=field.trait_type
662+
)
661663
elif key == "mandatory" and pydra_default is not None:
662664
val = False # Overwrite mandatory to False if default is provided
663665
pydra_metadata[pydra_key_nm] = val
@@ -666,7 +668,9 @@ def pydra_fld_input(self, field, nm):
666668
template = getattr(field, "name_template")
667669
name_source = ensure_list(getattr(field, "name_source"))
668670
if name_source:
669-
tmpl = self.string_formats(argstr=template, name=name_source[0])
671+
tmpl = self.string_formats(
672+
argstr=template, name=name_source[0], type_=field.trait_type
673+
)
670674
else:
671675
tmpl = template
672676
if nm in self.nipype_interface.output_spec().class_trait_names():
@@ -829,11 +833,14 @@ def pydra_type_converter(self, field, spec_type, name):
829833
pydra_type = ty.Any
830834
return pydra_type
831835

832-
def string_formats(self, argstr, name):
836+
def string_formats(self, argstr, name, type_):
833837
keys = re.findall(r"(%[0-9\.]*(?:s|d|i|g|f))", argstr)
834838
new_argstr = argstr
835839
for i, key in enumerate(keys):
836-
repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]"
840+
if isinstance(type_, traits.trait_types.Bool):
841+
repl = f"{name}:d"
842+
else:
843+
repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]"
837844
match = re.match(r"%([0-9\.]+)f", key)
838845
if match:
839846
repl += ":" + match.group(1)
@@ -972,7 +979,7 @@ def _converted_test(self):
972979
)
973980

974981
return spec_str, UsedSymbols(
975-
module_name=self.nipype_module.__name__, imports=imports
982+
module_name=self.nipype_module.__name__, import_stmts=imports
976983
)
977984

978985
def create_doctests(self, input_fields, nonstd_types):
@@ -1032,7 +1039,7 @@ def _misc_cleanups(self, body: str) -> str:
10321039
body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"')
10331040

10341041
body = body.replace("self.output_spec().get()", "{}")
1035-
body = body.replace("self._outputs()", "{}")
1042+
body = body.replace("self._outputs().get()", "{}")
10361043
# body = re.sub(
10371044
# r"outputs = self\.(output_spec|_outputs)\(\).*$",
10381045
# r"outputs = {}",

nipype2pydra/interface/function.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def types_to_names(spec_fields):
9191
translations=self.package.all_import_translations,
9292
absolute_imports=True,
9393
)
94-
used.update(method_used, from_other_module=False)
94+
used.update(method_used)
9595

9696
method_body = ""
9797
for field in input_fields:
@@ -129,7 +129,7 @@ def types_to_names(spec_fields):
129129
translations=self.package.all_import_translations,
130130
absolute_imports=True,
131131
)
132-
used.update(init_used, from_other_module=False)
132+
used.update(init_used)
133133
method_body += init_code + "\n"
134134

135135
# Combined src of run_interface and list_outputs
@@ -163,7 +163,7 @@ def types_to_names(spec_fields):
163163
translations=self.package.all_import_translations,
164164
absolute_imports=True,
165165
)
166-
used.update(run_interface_used, from_other_module=False)
166+
used.update(run_interface_used)
167167
method_body += run_interface_code + "\n"
168168

169169
list_outputs_code = inspect.getsource(
@@ -197,7 +197,7 @@ def types_to_names(spec_fields):
197197
translations=self.package.all_import_translations,
198198
absolute_imports=True,
199199
)
200-
used.update(list_outputs_used, from_other_module=False)
200+
used.update(list_outputs_used)
201201
method_body += list_outputs_code + "\n"
202202

203203
assert method_body, "Neither `run_interface` and `list_outputs` are defined"
@@ -250,12 +250,12 @@ def types_to_names(spec_fields):
250250
additional_imports.add(imprt)
251251
spec_str = repl_spec_str
252252

253-
used.imports.update(
253+
used.import_stmts.update(
254254
self.construct_imports(
255255
nonstd_types,
256256
spec_str,
257257
include_task=False,
258-
base=base_imports + list(used.imports) + list(additional_imports),
258+
base=base_imports + list(used.import_stmts) + list(additional_imports),
259259
)
260260
)
261261

nipype2pydra/interface/shell_command.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def types_to_names(spec_fields):
210210
)
211211
used.update(super_used)
212212

213-
used.imports.update(
213+
used.import_stmts.update(
214214
self.construct_imports(
215215
nonstd_types,
216216
spec_str,

nipype2pydra/package.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def write(self, package_root: Path, to_include: ty.List[str] = None):
400400
workflow.prepare_connections()
401401

402402
def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
403-
for _, klass in used.intra_pkg_classes:
403+
for _, klass in used.imported_classes:
404404
address = full_address(klass)
405405
if address in self.nipype_port_converters:
406406
if port_nipype:
@@ -412,10 +412,10 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
412412
)
413413
elif full_address(klass) not in self.interfaces:
414414
intra_pkg_modules[klass.__module__].add(klass)
415-
for _, func in used.intra_pkg_funcs:
415+
for _, func in used.imported_funcs:
416416
if full_address(func) not in list(self.workflows):
417417
intra_pkg_modules[func.__module__].add(func)
418-
for const_mod_address, _, const_name in used.intra_pkg_constants:
418+
for const_mod_address, _, const_name in used.imported_constants:
419419
intra_pkg_modules[const_mod_address].add(const_name)
420420

421421
for conv in list(self.functions.values()) + list(self.classes.values()):
@@ -429,7 +429,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
429429
package_root,
430430
already_converted=already_converted,
431431
)
432-
class_addrs = [full_address(c) for _, c in all_used.intra_pkg_classes]
432+
class_addrs = [full_address(c) for _, c in all_used.imported_classes]
433433
included_addrs = [c.full_address for c in interfaces_to_include]
434434
interfaces_to_include.extend(
435435
self.interfaces[a]
@@ -555,25 +555,23 @@ def write_intra_pkg_modules(
555555
always_include=self.all_explicit,
556556
)
557557

558-
classes = used.local_classes + [
559-
o for o in objs if inspect.isclass(o) and o not in used.local_classes
558+
classes = used.classes + [
559+
o for o in objs if inspect.isclass(o) and o not in used.classes
560560
]
561561

562-
functions = list(used.local_functions) + [
563-
o
564-
for o in objs
565-
if inspect.isfunction(o) and o not in used.local_functions
562+
functions = list(used.functions) + [
563+
o for o in objs if inspect.isfunction(o) and o not in used.functions
566564
]
567565

568566
self.write_to_module(
569567
package_root=package_root,
570568
module_name=out_mod_name,
571569
used=UsedSymbols(
572570
module_name=mod_name,
573-
imports=used.imports,
571+
import_stmts=used.import_stmts,
574572
constants=used.constants,
575-
local_classes=classes,
576-
local_functions=functions,
573+
classes=classes,
574+
functions=functions,
577575
),
578576
find_replace=self.find_replace,
579577
inline_intra_pkg=False,
@@ -871,11 +869,11 @@ def write_to_module(
871869
existing_imports = parse_imports(existing_import_strs, relative_to=module_name)
872870
converter_imports = []
873871

874-
for klass in used.local_classes:
872+
for klass in used.classes:
875873
if f"\nclass {klass.__name__}(" not in code_str:
876874
try:
877875
class_converter = self.classes[full_address(klass)]
878-
converter_imports.extend(class_converter.used_symbols.imports)
876+
converter_imports.extend(class_converter.used_symbols.import_stmts)
879877
except KeyError:
880878
class_converter = ClassConverter.from_object(klass, self)
881879
code_str += "\n" + class_converter.converted_code + "\n"
@@ -903,11 +901,13 @@ def write_to_module(
903901
if converted_code.strip() not in code_str:
904902
code_str += "\n" + converted_code + "\n"
905903

906-
for func in sorted(used.local_functions, key=attrgetter("__name__")):
904+
for func in sorted(used.functions, key=attrgetter("__name__")):
907905
if f"\ndef {func.__name__}(" not in code_str:
908906
if func.__name__ in self.functions:
909907
function_converter = self.functions[full_address(func)]
910-
converter_imports.extend(function_converter.used_symbols.imports)
908+
converter_imports.extend(
909+
function_converter.used_symbols.import_stmts
910+
)
911911
else:
912912
function_converter = FunctionConverter.from_object(func, self)
913913
code_str += "\n" + function_converter.converted_code + "\n"
@@ -923,7 +923,7 @@ def write_to_module(
923923
code_str += (
924924
"\n\n# Intra-package imports that have been inlined in this module\n\n"
925925
)
926-
for func_name, func in sorted(used.intra_pkg_funcs, key=itemgetter(0)):
926+
for func_name, func in sorted(used.imported_funcs, key=itemgetter(0)):
927927
func_src = get_source_code(func)
928928
func_src = re.sub(
929929
r"^(#[^\n]+\ndef) (\w+)(?=\()",
@@ -934,7 +934,7 @@ def write_to_module(
934934
code_str += "\n\n" + cleanup_function_body(func_src)
935935
inlined_symbols.append(func_name)
936936

937-
for klass_name, klass in sorted(used.intra_pkg_classes, key=itemgetter(0)):
937+
for klass_name, klass in sorted(used.imported_classes, key=itemgetter(0)):
938938
klass_src = get_source_code(klass)
939939
klass_src = re.sub(
940940
r"^(#[^\n]+\nclass) (\w+)(?=\()",
@@ -973,7 +973,7 @@ def write_to_module(
973973
imports = ImportStatement.collate(
974974
existing_imports
975975
+ converter_imports
976-
+ [i for i in used.imports if not i.indent]
976+
+ [i for i in used.import_stmts if not i.indent]
977977
+ GENERIC_PYDRA_IMPORTS
978978
+ additional_imports
979979
)

nipype2pydra/pkg_gen/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,13 +1123,13 @@ def insert_args_in_method_calls(
11231123
mod = import_module(mod_name)
11241124
used = UsedSymbols.find(mod, methods, omit_classes=(BaseInterface, TraitedSpec))
11251125
all_funcs.update(methods)
1126-
for func in used.local_functions:
1126+
for func in used.functions:
11271127
all_funcs.add(cleanup_function_body(get_source_code(func)))
1128-
for klass in used.local_classes:
1128+
for klass in used.classes:
11291129
klass_src = cleanup_function_body(get_source_code(klass))
11301130
if klass_src not in all_classes:
11311131
all_classes.append(klass_src)
1132-
for new_func_name, func in used.intra_pkg_funcs:
1132+
for new_func_name, func in used.imported_funcs:
11331133
if new_func_name is None:
11341134
continue # Not referenced directly in this module
11351135
func_src = get_source_code(func)
@@ -1148,7 +1148,7 @@ def insert_args_in_method_calls(
11481148
+ match.group(2)
11491149
)
11501150
all_funcs.add(cleanup_function_body(func_src))
1151-
for new_klass_name, klass in used.intra_pkg_classes:
1151+
for new_klass_name, klass in used.imported_classes:
11521152
if new_klass_name is None:
11531153
continue # Not referenced directly in this module
11541154
klass_src = get_source_code(klass)
@@ -1169,7 +1169,7 @@ def insert_args_in_method_calls(
11691169
klass_src = cleanup_function_body(klass_src)
11701170
if klass_src not in all_classes:
11711171
all_classes.append(klass_src)
1172-
all_imports.update(used.imports)
1172+
all_imports.update(used.import_stmts)
11731173
all_constants.update(used.constants)
11741174
return (
11751175
sorted(

nipype2pydra/statements/imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,8 @@ def parse_imports(
587587
"from fileformats.generic import File, Directory",
588588
"from pydra.engine.specs import MultiInputObj",
589589
"from pathlib import Path",
590+
"import json",
591+
"import yaml",
590592
"import logging",
591593
"import pydra.mark",
592594
"import typing as ty",

nipype2pydra/utils/misc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from importlib import import_module
2424
from logging import getLogger
25+
from pydra.engine.specs import MultiInputObj
2526

2627

2728
logger = getLogger("nipype2pydra")
@@ -482,12 +483,20 @@ def from_named_dicts_converter(
482483
def str_to_type(type_str: str) -> type:
483484
"""Resolve a string representation of a type into a valid type"""
484485
if "/" in type_str:
486+
if type_str.startswith("multi["):
487+
assert type_str.endswith("]"), f"Invalid multi type: {type_str}"
488+
type_str = type_str[6:-1]
489+
multi = True
490+
else:
491+
multi = False
485492
tp = from_mime(type_str)
486493
try:
487494
# If datatype is a field, use its primitive instead
488495
tp = tp.primitive # type: ignore
489496
except AttributeError:
490497
pass
498+
if multi:
499+
tp = MultiInputObj[tp]
491500
else:
492501

493502
def resolve_type(type_str: str) -> type:

0 commit comments

Comments
 (0)