77import typing as ty
88from importlib import import_module
99import logging
10+ import tempfile
1011from traceback import format_exc
1112import re
1213from tqdm import tqdm
1516import black .parsing
1617from fileformats .core import FileSet
1718from fileformats .medimage_mrtrix3 import ImageFormat , ImageIn , ImageOut , Tracks
18- from pydra .engine . helpers import make_klass
19- from pydra .engine import specs
19+ from pydra .design import shell
20+ from pydra .utils . typing import MultiInputObj
2021from pydra .utils import add_exc_note
22+ from pydra .engine .helpers import list_fields
2123
2224
2325logger = logging .getLogger ("pydra-auto-gen" )
@@ -176,7 +178,7 @@ def auto_gen_mrtrix3_pydra(
176178 manual_path = output_dir / "pydra" / "tasks" / "mrtrix3" / "manual"
177179 if manual_path .exists ():
178180 for manual_file in manual_path .iterdir ():
179- manual_cmd = manual_file .stem [: - 1 ]
181+ manual_cmd = manual_file .stem
180182 if not manual_cmd .startswith ("." ) and not manual_cmd .startswith ("__" ):
181183 manual_cmds .append (manual_cmd )
182184
@@ -205,9 +207,9 @@ def auto_gen_mrtrix3_pydra(
205207
206208 # Write init
207209 init_path = output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / "__init__.py"
208- imports = "\n " .join (f"from .{ c } _ import { pascal_case_task_name (c )} " for c in cmds )
210+ imports = "\n " .join (f"from .{ c } import { pascal_case_task_name (c )} " for c in cmds )
209211 imports += "\n " + "\n " .join (
210- f"from ..manual.{ c } _ import { pascal_case_task_name (c )} " for c in manual_cmds
212+ f"from ..manual.{ c } import { pascal_case_task_name (c )} " for c in manual_cmds
211213 )
212214 init_path .write_text (f"# Auto-generated, do not edit\n \n { imports } \n " )
213215
@@ -267,20 +269,29 @@ def auto_gen_cmd(
267269 code_str = code_str .replace (f"{ old_name } _output" , f"{ cmd_name } _output" )
268270 code_str = re .sub (r"(?<!\w)5tt_in(?!\w)" , "in_5tt" , code_str )
269271 try :
270- code_str = black .format_file_contents (
271- code_str , fast = False , mode = black .FileMode ()
272- )
273- except black .report .NothingChanged :
274- pass
275- except black .parsing .InvalidInput :
276- if log_errors :
277- logger .error ("Could not parse generated interface for '%s'" , cmd_name )
278- logger .error (format_exc ())
279- return []
280- else :
281- raise
272+ try :
273+ code_str = black .format_file_contents (
274+ code_str , fast = False , mode = black .FileMode ()
275+ )
276+ except black .report .NothingChanged :
277+ pass
278+ except black .parsing .InvalidInput :
279+ if log_errors :
280+ logger .error (
281+ "Could not parse generated interface (%s) for '%s'" , cmd_name
282+ )
283+ logger .error (format_exc ())
284+ return []
285+ else :
286+ raise
287+ except Exception as e :
288+ tfile = Path (tempfile .mkdtemp ()) / (cmd_name + ".py" )
289+ tfile .write_text (code_str )
290+ e .add_note (f"when formatting { cmd_name } " )
291+ e .add_note (f"generated file is { tfile } " )
292+ raise e
282293 output_path = (
283- output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / (cmd_name + "_ .py" )
294+ output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / (cmd_name + ".py" )
284295 )
285296 output_path .parent .mkdir (exist_ok = True , parents = True )
286297 with open (output_path , "w" ) as f :
@@ -301,9 +312,12 @@ def auto_gen_cmd(
301312def auto_gen_test (cmd_name : str , output_dir : Path , log_errors : bool , pkg_version : str ):
302313 tests_dir = output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / "tests"
303314 tests_dir .mkdir (exist_ok = True )
304- module = import_module (f"pydra.tasks.mrtrix3.{ pkg_version } .{ cmd_name } _" )
305- interface = getattr (module , pascal_case_task_name (cmd_name ))
306- task = interface ()
315+ module = import_module (f"pydra.tasks.mrtrix3.{ pkg_version } .{ cmd_name } " )
316+ definition_klass = getattr (module , pascal_case_task_name (cmd_name ))
317+
318+ input_fields = list_fields (definition_klass )
319+ output_fields = list_fields (definition_klass .Outputs )
320+ output_fields_dict = {f .name : f for f in output_fields }
307321
308322 code_str = f"""# Auto-generated test for { cmd_name }
309323
@@ -324,9 +338,8 @@ def test_{cmd_name.lower()}(tmp_path, cli_parse_only):
324338
325339 task = { pascal_case_task_name (cmd_name )} (
326340"""
327- input_fields = attrs .fields (type (task .inputs ))
328- output_fields = attrs .fields (make_klass (task .output_spec ))
329341
342+ field : shell .arg
330343 for field in input_fields :
331344 if field .name in (
332345 "executable" ,
@@ -335,6 +348,7 @@ def test_{cmd_name.lower()}(tmp_path, cli_parse_only):
335348 "quiet" ,
336349 "info" ,
337350 "nthreads" ,
351+ "additional_args" ,
338352 "config" ,
339353 "args" ,
340354 ):
@@ -355,12 +369,12 @@ def get_value(type_):
355369 value = "True"
356370 elif type_ is Path :
357371 try :
358- output_field = getattr ( output_fields , field .name )
372+ output_field = output_fields_dict [ field .name ]
359373 except AttributeError :
360374 pass
361375 else :
362376 output_type = output_field .type
363- if ty .get_origin (output_type ) is specs . MultiInputObj :
377+ if ty .get_origin (output_type ) is MultiInputObj :
364378 output_type = ty .get_args (output_type )[0 ]
365379 if ty .get_origin (output_type ) in (list , tuple ):
366380 output_type = ty .get_args (output_type )[0 ]
@@ -369,7 +383,7 @@ def get_value(type_):
369383 value = f"{ output_type .__name__ } .sample()"
370384 elif ty .get_origin (type_ ) is ty .Union :
371385 value = get_value (ty .get_args (type_ )[0 ])
372- elif ty .get_origin (type_ ) is specs . MultiInputObj :
386+ elif ty .get_origin (type_ ) is MultiInputObj :
373387 value = "[" + get_value (ty .get_args (type_ )[0 ]) + "]"
374388 elif ty .get_origin (type_ ) and issubclass (ty .get_origin (type_ ), ty .Sequence ):
375389 value = (
@@ -386,8 +400,8 @@ def get_value(type_):
386400
387401 if field .default is not attrs .NOTHING :
388402 value = field .default
389- elif "allowed_values" in field .metadata :
390- value = repr (field .metadata [ " allowed_values" ] [0 ])
403+ elif field .allowed_values :
404+ value = repr (field .allowed_values [0 ])
391405 else :
392406 value = get_value (field .type )
393407
0 commit comments