7
7
import typing as ty
8
8
from importlib import import_module
9
9
import logging
10
+ import tempfile
10
11
from traceback import format_exc
11
12
import re
12
13
from tqdm import tqdm
15
16
import black .parsing
16
17
from fileformats .core import FileSet
17
18
from fileformats .medimage_mrtrix3 import ImageFormat , ImageIn , ImageOut , Tracks
18
- from pydra .engine . helpers import make_klass
19
- from pydra .engine import specs
19
+ from pydra .design import shell
20
+ from pydra .utils . typing import MultiInputObj
20
21
from pydra .utils import add_exc_note
22
+ from pydra .engine .helpers import list_fields
21
23
22
24
23
25
logger = logging .getLogger ("pydra-auto-gen" )
@@ -176,7 +178,7 @@ def auto_gen_mrtrix3_pydra(
176
178
manual_path = output_dir / "pydra" / "tasks" / "mrtrix3" / "manual"
177
179
if manual_path .exists ():
178
180
for manual_file in manual_path .iterdir ():
179
- manual_cmd = manual_file .stem [: - 1 ]
181
+ manual_cmd = manual_file .stem
180
182
if not manual_cmd .startswith ("." ) and not manual_cmd .startswith ("__" ):
181
183
manual_cmds .append (manual_cmd )
182
184
@@ -205,9 +207,9 @@ def auto_gen_mrtrix3_pydra(
205
207
206
208
# Write init
207
209
init_path = output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / "__init__.py"
208
- imports = "\n " .join (f"from .{ c } _ import { pascal_case_task_name (c )} " for c in cmds )
210
+ imports = "\n " .join (f"from .{ c } import { pascal_case_task_name (c )} " for c in cmds )
209
211
imports += "\n " + "\n " .join (
210
- f"from ..manual.{ c } _ import { pascal_case_task_name (c )} " for c in manual_cmds
212
+ f"from ..manual.{ c } import { pascal_case_task_name (c )} " for c in manual_cmds
211
213
)
212
214
init_path .write_text (f"# Auto-generated, do not edit\n \n { imports } \n " )
213
215
@@ -267,20 +269,29 @@ def auto_gen_cmd(
267
269
code_str = code_str .replace (f"{ old_name } _output" , f"{ cmd_name } _output" )
268
270
code_str = re .sub (r"(?<!\w)5tt_in(?!\w)" , "in_5tt" , code_str )
269
271
try :
270
- code_str = black .format_file_contents (
271
- code_str , fast = False , mode = black .FileMode ()
272
- )
273
- except black .report .NothingChanged :
274
- pass
275
- except black .parsing .InvalidInput :
276
- if log_errors :
277
- logger .error ("Could not parse generated interface for '%s'" , cmd_name )
278
- logger .error (format_exc ())
279
- return []
280
- else :
281
- raise
272
+ try :
273
+ code_str = black .format_file_contents (
274
+ code_str , fast = False , mode = black .FileMode ()
275
+ )
276
+ except black .report .NothingChanged :
277
+ pass
278
+ except black .parsing .InvalidInput :
279
+ if log_errors :
280
+ logger .error (
281
+ "Could not parse generated interface (%s) for '%s'" , cmd_name
282
+ )
283
+ logger .error (format_exc ())
284
+ return []
285
+ else :
286
+ raise
287
+ except Exception as e :
288
+ tfile = Path (tempfile .mkdtemp ()) / (cmd_name + ".py" )
289
+ tfile .write_text (code_str )
290
+ e .add_note (f"when formatting { cmd_name } " )
291
+ e .add_note (f"generated file is { tfile } " )
292
+ raise e
282
293
output_path = (
283
- output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / (cmd_name + "_ .py" )
294
+ output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / (cmd_name + ".py" )
284
295
)
285
296
output_path .parent .mkdir (exist_ok = True , parents = True )
286
297
with open (output_path , "w" ) as f :
@@ -301,9 +312,12 @@ def auto_gen_cmd(
301
312
def auto_gen_test (cmd_name : str , output_dir : Path , log_errors : bool , pkg_version : str ):
302
313
tests_dir = output_dir / "pydra" / "tasks" / "mrtrix3" / pkg_version / "tests"
303
314
tests_dir .mkdir (exist_ok = True )
304
- module = import_module (f"pydra.tasks.mrtrix3.{ pkg_version } .{ cmd_name } _" )
305
- interface = getattr (module , pascal_case_task_name (cmd_name ))
306
- task = interface ()
315
+ module = import_module (f"pydra.tasks.mrtrix3.{ pkg_version } .{ cmd_name } " )
316
+ definition_klass = getattr (module , pascal_case_task_name (cmd_name ))
317
+
318
+ input_fields = list_fields (definition_klass )
319
+ output_fields = list_fields (definition_klass .Outputs )
320
+ output_fields_dict = {f .name : f for f in output_fields }
307
321
308
322
code_str = f"""# Auto-generated test for { cmd_name }
309
323
@@ -324,9 +338,8 @@ def test_{cmd_name.lower()}(tmp_path, cli_parse_only):
324
338
325
339
task = { pascal_case_task_name (cmd_name )} (
326
340
"""
327
- input_fields = attrs .fields (type (task .inputs ))
328
- output_fields = attrs .fields (make_klass (task .output_spec ))
329
341
342
+ field : shell .arg
330
343
for field in input_fields :
331
344
if field .name in (
332
345
"executable" ,
@@ -335,6 +348,7 @@ def test_{cmd_name.lower()}(tmp_path, cli_parse_only):
335
348
"quiet" ,
336
349
"info" ,
337
350
"nthreads" ,
351
+ "additional_args" ,
338
352
"config" ,
339
353
"args" ,
340
354
):
@@ -355,12 +369,12 @@ def get_value(type_):
355
369
value = "True"
356
370
elif type_ is Path :
357
371
try :
358
- output_field = getattr ( output_fields , field .name )
372
+ output_field = output_fields_dict [ field .name ]
359
373
except AttributeError :
360
374
pass
361
375
else :
362
376
output_type = output_field .type
363
- if ty .get_origin (output_type ) is specs . MultiInputObj :
377
+ if ty .get_origin (output_type ) is MultiInputObj :
364
378
output_type = ty .get_args (output_type )[0 ]
365
379
if ty .get_origin (output_type ) in (list , tuple ):
366
380
output_type = ty .get_args (output_type )[0 ]
@@ -369,7 +383,7 @@ def get_value(type_):
369
383
value = f"{ output_type .__name__ } .sample()"
370
384
elif ty .get_origin (type_ ) is ty .Union :
371
385
value = get_value (ty .get_args (type_ )[0 ])
372
- elif ty .get_origin (type_ ) is specs . MultiInputObj :
386
+ elif ty .get_origin (type_ ) is MultiInputObj :
373
387
value = "[" + get_value (ty .get_args (type_ )[0 ]) + "]"
374
388
elif ty .get_origin (type_ ) and issubclass (ty .get_origin (type_ ), ty .Sequence ):
375
389
value = (
@@ -386,8 +400,8 @@ def get_value(type_):
386
400
387
401
if field .default is not attrs .NOTHING :
388
402
value = field .default
389
- elif "allowed_values" in field .metadata :
390
- value = repr (field .metadata [ " allowed_values" ] [0 ])
403
+ elif field .allowed_values :
404
+ value = repr (field .allowed_values [0 ])
391
405
else :
392
406
value = get_value (field .type )
393
407
0 commit comments