Skip to content

Commit 8644b63

Browse files
committed
reprocess record types
1 parent 278a62b commit 8644b63

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

cwltool/main.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ def output_callback(out, processStatus):
241241

242242
return final_output[0]
243243

244-
245244
class FSAction(argparse.Action):
246245
objclass = None # type: Text
247246

@@ -294,8 +293,9 @@ class DirectoryAppendAction(FSAppendAction):
294293
objclass = "Directory"
295294

296295

297-
def add_argument(toolparser, name, inptype, description="", default=None):
298-
# type: (argparse.ArgumentParser, Text, Any, Text, Any) -> None
296+
def add_argument(toolparser, name, inptype, records, description="",
297+
default=None):
298+
# type: (argparse.ArgumentParser, Text, Any, List[Text], Text, Any) -> None
299299
if len(name) == 1:
300300
flag = "-"
301301
else:
@@ -329,12 +329,14 @@ def add_argument(toolparser, name, inptype, description="", default=None):
329329
elif isinstance(inptype, dict) and inptype["type"] == "enum":
330330
atype = Text
331331
elif isinstance(inptype, dict) and inptype["type"] == "record":
332+
records.append(name)
332333
for field in inptype['fields']:
333334
fieldname = name+"."+shortname(field['name'])
334335
fieldtype = field['type']
335336
fielddescription = field.get("doc", "")
336337
add_argument(
337-
toolparser, fieldname, fieldtype, fielddescription)
338+
toolparser, fieldname, fieldtype, records,
339+
fielddescription)
338340
return
339341
if inptype == "string":
340342
atype = Text
@@ -364,8 +366,8 @@ def add_argument(toolparser, name, inptype, description="", default=None):
364366
default=default, **typekw)
365367

366368

367-
def generate_parser(toolparser, tool, namemap):
368-
# type: (argparse.ArgumentParser, Process, Dict[Text, Text]) -> argparse.ArgumentParser
369+
def generate_parser(toolparser, tool, namemap, records):
370+
# type: (argparse.ArgumentParser, Process, Dict[Text, Text], List[Text]) -> argparse.ArgumentParser
369371
toolparser.add_argument("job_order", nargs="?", help="Job input json file")
370372
namemap["job_order"] = "job_order"
371373

@@ -375,7 +377,7 @@ def generate_parser(toolparser, tool, namemap):
375377
inptype = inp["type"]
376378
description = inp.get("doc", "")
377379
default = inp.get("default", None)
378-
add_argument(toolparser, name, inptype, description, default)
380+
add_argument(toolparser, name, inptype, records, description, default)
379381

380382
return toolparser
381383

@@ -418,12 +420,23 @@ def load_job_order(args, t, stdin, print_input_deps=False, relative_deps=False,
418420
else:
419421
input_basedir = args.basedir if args.basedir else os.getcwd()
420422
namemap = {} # type: Dict[Text, Text]
421-
toolparser = generate_parser(argparse.ArgumentParser(prog=args.workflow), t, namemap)
423+
records = [] # type: List[Text]
424+
toolparser = generate_parser(
425+
argparse.ArgumentParser(prog=args.workflow), t, namemap, records)
422426
if toolparser:
423427
if args.tool_help:
424428
toolparser.print_help()
425429
return 0
426430
cmd_line = vars(toolparser.parse_args(args.job_order))
431+
for record_name in records:
432+
record = {}
433+
record_items = {
434+
k:v for k,v in cmd_line.iteritems()
435+
if k.startswith(record_name)}
436+
for key, value in record_items.iteritems():
437+
record[key[len(record_name)+1:]] = value
438+
del cmd_line[key]
439+
cmd_line[str(record_name)] = record
427440

428441
if cmd_line["job_order"]:
429442
try:

tests/test_toolargparse.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class ToolArgparse(unittest.TestCase):
5858
one: File
5959
two: string
6060
61+
expression: $(inputs.foo.two)
62+
6163
outputs: []
6264
'''
6365

@@ -77,7 +79,7 @@ def test_bool(self):
7779
except SystemExit as e:
7880
self.assertEquals(e.code, 0)
7981

80-
def test_record(self):
82+
def test_record_help(self):
8183
with NamedTemporaryFile() as f:
8284
f.write(self.script3)
8385
f.flush()
@@ -86,6 +88,17 @@ def test_record(self):
8688
except SystemExit as e:
8789
self.assertEquals(e.code, 0)
8890

91+
def test_record(self):
92+
with NamedTemporaryFile() as f:
93+
f.write(self.script3)
94+
f.flush()
95+
try:
96+
self.assertEquals(main([f.name, '--foo.one', 'README.rst',
97+
'--foo.two', 'test']), 0)
98+
except SystemExit as e:
99+
self.assertEquals(e.code, 0)
100+
101+
89102

90103
if __name__ == '__main__':
91104
unittest.main()

0 commit comments

Comments
 (0)