Skip to content

Commit 10492ac

Browse files
author
Peter Amstutz
authored
Refactor main (#1206)
* Refactor main function: pull out blocks of code from main into separate functions.
1 parent a8d8d00 commit 10492ac

File tree

2 files changed

+170
-112
lines changed

2 files changed

+170
-112
lines changed

cwltool/load_tool.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@
6262
[Dict[Text, Union[Text, bool]], requests.sessions.Session], Fetcher]
6363
ResolverType = Callable[[Loader, Union[Text, Dict[Text, Any]]], Text]
6464

65-
def default_loader(fetcher_constructor=None):
    # type: (Optional[FetcherConstructorType]) -> Loader
    """Create a Loader over ``jobloaderctx``.

    :param fetcher_constructor: optional factory handed to the Loader as its
        ``fetcher_constructor``; ``None`` is passed through unchanged.
    :return: the newly constructed Loader.
    """
    loader = Loader(jobloaderctx, fetcher_constructor=fetcher_constructor)
    return loader
65+
def default_loader(fetcher_constructor=None, enable_dev=False):
    # type: (Optional[FetcherConstructorType], bool) -> Loader
    """Create a Loader over ``jobloaderctx``.

    :param fetcher_constructor: optional factory handed to the Loader as its
        ``fetcher_constructor``; ``None`` is passed through unchanged.
    :param enable_dev: value returned by the ``allow_attachments`` callback
        given to the Loader, regardless of the argument it is called with.
    :return: the newly constructed Loader.
    """
    def attachments_allowed(r):  # pylint: disable=unused-argument
        # type: (Any) -> bool
        """Answer with the ``enable_dev`` flag captured at construction time."""
        return enable_dev

    return Loader(
        jobloaderctx,
        fetcher_constructor=fetcher_constructor,
        allow_attachments=attachments_allowed)
6869

6970
def resolve_tool_uri(argsworkflow, # type: Text
7071
resolver=None, # type: Optional[ResolverType]

cwltool/main.py

Lines changed: 166 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ def load_job_order(args, # type: argparse.Namespace
275275
sys.exit(1)
276276
return (job_order_object, input_basedir, loader)
277277

278-
279278
def init_job_order(job_order_object, # type: Optional[MutableMapping[Text, Any]]
280279
args, # type: argparse.Namespace
281280
process, # type: Process
@@ -389,7 +388,6 @@ def expand_formats(p): # type: (Dict[Text, Any]) -> None
389388
return job_order_object
390389

391390

392-
393391
def make_relative(base, obj): # type: (Text, Dict[Text, Any]) -> None
394392
"""Relativize the location URI of a File or Directory object."""
395393
uri = obj.get("location", obj.get("path"))
@@ -480,6 +478,157 @@ def supported_cwl_versions(enable_dev): # type: (bool) -> List[Text]
480478
versions.sort()
481479
return versions
482480

481+
def configure_logging(args,            # type: argparse.Namespace
                      stderr_handler,  # type: logging.Handler
                      runtimeContext   # type: RuntimeContext
                     ):  # type: (...) -> None
    """Apply the logging-related options to the stderr handler and loggers.

    :param args: parsed options; ``quiet``, ``enable_color`` and
        ``timestamps`` are read.
    :param stderr_handler: handler that is attached to the ``rdflib.term``
        logger and receives the level and formatter chosen here.
    :param runtimeContext: its ``debug`` flag switches everything to DEBUG.
    """
    rdflib_log = logging.getLogger("rdflib.term")
    rdflib_log.addHandler(stderr_handler)
    rdflib_log.setLevel(logging.ERROR)

    if args.quiet:
        # Silence STDERR only, not an eventual provenance log file
        stderr_handler.setLevel(logging.WARN)

    if runtimeContext.debug:
        # Debug wins over quiet: raise verbosity for stderr and the
        # provenance log file alike
        for target in (_logger, stderr_handler, rdflib_log):
            target.setLevel(logging.DEBUG)

    if args.enable_color:
        formatter_cls = coloredlogs.ColoredFormatter
    else:
        formatter_cls = logging.Formatter

    if args.timestamps:
        format_args = ("[%(asctime)s] %(levelname)s %(message)s",
                       "%Y-%m-%d %H:%M:%S")
    else:
        format_args = ("%(levelname)s %(message)s",)
    stderr_handler.setFormatter(formatter_cls(*format_args))
504+
505+
def setup_schema(args,                   # type: argparse.Namespace
                 custom_schema_callback  # type: Optional[Callable[[], None]]
                ):  # type: (...) -> None
    """Choose which schema is active for the rest of the run.

    Exactly one of three things happens, in this order of precedence:

    1. ``custom_schema_callback`` is called when it is not ``None``;
    2. otherwise, when ``args.enable_ext`` is set, the packaged
       ``extensions.yml`` resource is read and passed to
       ``use_custom_schema``;
    3. otherwise ``use_standard_schema`` is called.

    :param args: parsed options; only ``enable_ext`` is read, and only when
        no callback is supplied.
    :param custom_schema_callback: optional zero-argument callable that takes
        over schema setup entirely.
    """
    if custom_schema_callback is not None:
        custom_schema_callback()
    elif args.enable_ext:
        res = pkg_resources.resource_stream(__name__, 'extensions.yml')
        try:
            use_custom_schema("v1.0", "http://commonwl.org/cwltool", res.read())
        finally:
            # Close the stream even when reading or registering the schema
            # raises, so the resource handle is never leaked.
            res.close()
    else:
        use_standard_schema("v1.0")
516+
517+
def setup_provenance(args,            # type: argparse.Namespace
                     argsl,           # type: List[str]
                     runtimeContext   # type: RuntimeContext
                    ):  # type: (...) -> Optional[int]
    """Create the research object and route log output into its activity log.

    :param args: parsed options; ``compute_checksum``, ``tmpdir_prefix``,
        ``orcid`` and ``cwl_full_name`` are read.
    :param argsl: raw command-line arguments, logged when not ``None``.
    :param runtimeContext: receives the new ResearchObject as
        ``research_obj``; its ``make_fs_access`` is used if set.
    :return: ``1`` when checksums are disabled (an error is logged and
        nothing else is done), otherwise ``None``.
    """
    if not args.compute_checksum:
        _logger.error("--provenance incompatible with --no-compute-checksum")
        return 1

    class ProvLogFormatter(logging.Formatter):
        """Enforce ISO8601 with both T and Z."""

        def __init__(self):  # type: () -> None
            super(ProvLogFormatter, self).__init__(
                "[%(asctime)sZ] %(message)s")

        def formatTime(self, record, datefmt=None):
            # type: (logging.LogRecord, Optional[str]) -> str
            # Always UTC; ``datefmt`` is deliberately ignored. The literal
            # "Z" suffix comes from the format string above.
            stamp = time.strftime(
                "%Y-%m-%dT%H:%M:%S", time.gmtime(record.created))
            return "%s,%03d" % (stamp, record.msecs)

    research_object = ResearchObject(
        getdefault(runtimeContext.make_fs_access, StdFsAccess),
        temp_prefix_ro=args.tmpdir_prefix,
        orcid=args.orcid,
        full_name=args.cwl_full_name)
    runtimeContext.research_obj = research_object

    activity_log = research_object.open_log_file_for_activity(
        research_object.engine_uuid)
    provenance_handler = logging.StreamHandler(cast(IO[str], activity_log))
    provenance_handler.setFormatter(ProvLogFormatter())
    _logger.addHandler(provenance_handler)

    _logger.debug(u"[provenance] Logging to %s", activity_log)
    if argsl is not None:
        # Log cwltool command line options to provenance file
        _logger.info("[cwltool] %s %s", sys.argv[0], u" ".join(argsl))
    _logger.debug(u"[cwltool] Arguments: %s", args)
    return None
553+
554+
def setup_loadingContext(loadingContext,  # type: Optional[LoadingContext]
                         runtimeContext,  # type: RuntimeContext
                         args             # type: argparse.Namespace
                        ):  # type: (...) -> LoadingContext
    """Return a loading context populated from the options and runtime context.

    The caller's context is never modified: a new one is built from
    ``vars(args)`` when none is given, otherwise a copy is taken.

    :param loadingContext: existing context to copy, or ``None``.
    :param runtimeContext: source of ``research_obj``.
    :param args: parsed options; ``enable_dev``, ``disable_js_validation``,
        ``do_validate``, ``pack`` and ``print_subgraph`` are read.
    :return: the new, configured LoadingContext.
    """
    context = (LoadingContext(vars(args)) if loadingContext is None
               else loadingContext.copy())

    context.loader = default_loader(
        context.fetcher_constructor, enable_dev=args.enable_dev)
    context.research_obj = runtimeContext.research_obj
    # JS validation is also off whenever validation as a whole is off
    context.disable_js_validation = \
        args.disable_js_validation or (not args.do_validate)
    context.construct_tool_object = getdefault(
        context.construct_tool_object, workflow.default_make_tool)
    context.resolver = getdefault(context.resolver, tool_resolver)

    if context.do_update is None:
        # Only fill in a default; an explicit True/False is left alone
        context.do_update = not args.pack and not args.print_subgraph

    return context
574+
575+
def make_template(tool  # type: Process
                 ):  # type: (...) -> None
    """Write an input-object template for ``tool`` to standard output as YAML.

    As a side effect, a representer for ``None`` is registered on
    ``yaml.RoundTripRepresenter`` so that missing values are emitted as the
    literal ``null``.

    :param tool: process passed to ``generate_input_template``.
    """
    def represent_null(self, data):  # pylint: disable=unused-argument
        # type: (Any, Any) -> Any
        """Force clean representation of 'null'."""
        return self.represent_scalar(u'tag:yaml.org,2002:null', u'null')

    yaml.RoundTripRepresenter.add_representer(type(None), represent_null)
    template = generate_input_template(tool)
    yaml.round_trip_dump(
        template,
        sys.stdout,
        default_flow_style=False,
        indent=4,
        block_seq_indent=2)
585+
586+
587+
def choose_target(args,           # type: argparse.Namespace
                  tool,           # type: Process
                  loadingContext  # type: LoadingContext
                 ):  # type: (...) -> Optional[Process]
    """Build a tool from the subgraph of ``tool`` selected by ``args.target``.

    :param args: parsed options; ``target`` lists the names to select.
    :param tool: the loaded process; must be a Workflow.
    :param loadingContext: supplies the loader (its fetcher and index) and is
        passed on to ``make_tool``.
    :return: the tool made from the extracted subgraph, or ``None`` (after
        logging an error) when ``tool`` is not a Workflow.
    :raises Exception: if ``loadingContext.loader`` is ``None`` or its
        ``idx`` is not a CommentedMap.
    """
    loader = loadingContext.loader
    if loader is None:
        raise Exception("loadingContext.loader cannot be None")

    if not isinstance(tool, Workflow):
        _logger.error("Can only use --target on Workflows")
        return None

    workflow_id = tool.tool["id"]
    if urllib.parse.urlparse(workflow_id).fragment:
        # The id already carries a fragment, so targets nest beneath it
        target_ids = [workflow_id + "/" + name for name in args.target]
    else:
        # No fragment yet: each target becomes the fragment of the id
        target_ids = [loader.fetcher.urljoin(workflow_id, "#" + name)
                      for name in args.target]
    extracted = get_subgraph(target_ids, tool)

    if not isinstance(loader.idx, CommentedMap):
        raise Exception("Missing loadingContext.loader.idx!")

    # Register the extracted document so make_tool can find it by id
    loader.idx[extracted["id"]] = extracted
    return make_tool(extracted["id"], loadingContext)
614+
615+
def check_working_directories(runtimeContext  # type: RuntimeContext
                             ):  # type: (...) -> Optional[int]
    """Make the configured directory prefixes absolute and ensure they exist.

    For each of ``tmpdir_prefix``, ``tmp_outdir_prefix`` and ``cachedir``
    that is set to something other than ``DEFAULT_TMP_PREFIX``, the value on
    ``runtimeContext`` is replaced by its absolute path and the directory
    part of that path is created if it does not exist.

    :param runtimeContext: context whose three attributes are read and
        updated in place.
    :return: ``1`` when a directory could not be created (an error is
        logged), otherwise ``None``.
    """
    for attr_name in ("tmpdir_prefix", "tmp_outdir_prefix", "cachedir"):
        configured = getattr(runtimeContext, attr_name)
        if not configured or configured == DEFAULT_TMP_PREFIX:
            continue

        # abspath() drops a trailing slash; restore it when the user gave
        # one, and always add it for cachedir so it is treated as a directory
        keep_slash = configured.endswith("/") or attr_name == "cachedir"
        normalized = os.path.abspath(configured) + ("/" if keep_slash else "")
        setattr(runtimeContext, attr_name, normalized)

        directory = os.path.dirname(normalized)
        if os.path.exists(directory):
            continue
        try:
            os.makedirs(directory)
        except Exception as err:
            _logger.error("Failed to create directory: %s", Text(err))
            return 1
    return None
630+
631+
483632
def main(argsl=None, # type: Optional[List[str]]
484633
args=None, # type: Optional[argparse.Namespace]
485634
job_order_object=None, # type: Optional[MutableMapping[Text, Any]]
@@ -545,26 +694,7 @@ def main(argsl=None, # type: Optional[List[str]]
545694
if not hasattr(args, key):
546695
setattr(args, key, val)
547696

548-
# Configure logging
549-
rdflib_logger = logging.getLogger("rdflib.term")
550-
rdflib_logger.addHandler(stderr_handler)
551-
rdflib_logger.setLevel(logging.ERROR)
552-
if args.quiet:
553-
# Silence STDERR, not an eventual provenance log file
554-
stderr_handler.setLevel(logging.WARN)
555-
if runtimeContext.debug:
556-
# Increase to debug for both stderr and provenance log file
557-
_logger.setLevel(logging.DEBUG)
558-
stderr_handler.setLevel(logging.DEBUG)
559-
rdflib_logger.setLevel(logging.DEBUG)
560-
fmtclass = coloredlogs.ColoredFormatter if args.enable_color else logging.Formatter
561-
formatter = fmtclass("%(levelname)s %(message)s")
562-
if args.timestamps:
563-
formatter = fmtclass(
564-
"[%(asctime)s] %(levelname)s %(message)s",
565-
"%Y-%m-%d %H:%M:%S")
566-
stderr_handler.setFormatter(formatter)
567-
##
697+
configure_logging(args, stderr_handler, runtimeContext)
568698

569699
if args.version:
570700
print(versionfunc())
@@ -590,60 +720,15 @@ def main(argsl=None, # type: Optional[List[str]]
590720
if not args.enable_ga4gh_tool_registry:
591721
del ga4gh_tool_registries[:]
592722

593-
if custom_schema_callback is not None:
594-
custom_schema_callback()
595-
elif args.enable_ext:
596-
res = pkg_resources.resource_stream(__name__, 'extensions.yml')
597-
use_custom_schema("v1.0", "http://commonwl.org/cwltool", res.read())
598-
res.close()
599-
else:
600-
use_standard_schema("v1.0")
723+
setup_schema(args, custom_schema_callback)
724+
601725
if args.provenance:
602-
if not args.compute_checksum:
603-
_logger.error("--provenance incompatible with --no-compute-checksum")
726+
if argsl is None:
727+
raise Exception("argsl cannot be None")
728+
if setup_provenance(args, argsl, runtimeContext) is not None:
604729
return 1
605-
ro = ResearchObject(
606-
getdefault(runtimeContext.make_fs_access, StdFsAccess),
607-
temp_prefix_ro=args.tmpdir_prefix, orcid=args.orcid,
608-
full_name=args.cwl_full_name)
609-
runtimeContext.research_obj = ro
610-
log_file_io = ro.open_log_file_for_activity(ro.engine_uuid)
611-
prov_log_handler = logging.StreamHandler(cast(IO[str], log_file_io))
612-
613-
class ProvLogFormatter(logging.Formatter):
614-
"""Enforce ISO8601 with both T and Z."""
615-
616-
def __init__(self): # type: () -> None
617-
super(ProvLogFormatter, self).__init__(
618-
"[%(asctime)sZ] %(message)s")
619-
620-
def formatTime(self, record, datefmt=None):
621-
# type: (logging.LogRecord, Optional[str]) -> str
622-
record_time = time.gmtime(record.created)
623-
formatted_time = time.strftime("%Y-%m-%dT%H:%M:%S", record_time)
624-
with_msecs = "%s,%03d" % (formatted_time, record.msecs)
625-
return with_msecs
626-
prov_log_handler.setFormatter(ProvLogFormatter())
627-
_logger.addHandler(prov_log_handler)
628-
_logger.debug(u"[provenance] Logging to %s", log_file_io)
629-
if argsl is not None:
630-
# Log cwltool command line options to provenance file
631-
_logger.info("[cwltool] %s %s", sys.argv[0], u" ".join(argsl))
632-
_logger.debug(u"[cwltool] Arguments: %s", args)
633-
634-
if loadingContext is None:
635-
loadingContext = LoadingContext(vars(args))
636-
else:
637-
loadingContext = loadingContext.copy()
638-
loadingContext.loader = default_loader(loadingContext.fetcher_constructor)
639-
loadingContext.research_obj = runtimeContext.research_obj
640-
loadingContext.disable_js_validation = \
641-
args.disable_js_validation or (not args.do_validate)
642-
loadingContext.construct_tool_object = getdefault(
643-
loadingContext.construct_tool_object, workflow.default_make_tool)
644-
loadingContext.resolver = getdefault(loadingContext.resolver, tool_resolver)
645-
if loadingContext.do_update is None:
646-
loadingContext.do_update = not (args.pack or args.print_subgraph)
730+
731+
loadingContext = setup_loadingContext(loadingContext, runtimeContext, args)
647732

648733
uri, tool_file_uri = resolve_tool_uri(
649734
args.workflow, resolver=loadingContext.resolver,
@@ -692,14 +777,7 @@ def formatTime(self, record, datefmt=None):
692777

693778
tool = make_tool(uri, loadingContext)
694779
if args.make_template:
695-
def my_represent_none(self, data): # pylint: disable=unused-argument
696-
# type: (Any, Any) -> Any
697-
"""Force clean representation of 'null'."""
698-
return self.represent_scalar(u'tag:yaml.org,2002:null', u'null')
699-
yaml.RoundTripRepresenter.add_representer(type(None), my_represent_none)
700-
yaml.round_trip_dump(
701-
generate_input_template(tool), sys.stdout,
702-
default_flow_style=False, indent=4, block_seq_indent=2)
780+
make_template(tool)
703781
return 0
704782

705783
if args.validate:
@@ -722,23 +800,11 @@ def my_represent_none(self, data): # pylint: disable=unused-argument
722800
return 0
723801

724802
if args.target:
725-
if isinstance(tool, Workflow):
726-
url = urllib.parse.urlparse(tool.tool["id"])
727-
if url.fragment:
728-
extracted = get_subgraph([tool.tool["id"] + "/" + r for r in args.target], tool)
729-
else:
730-
extracted = get_subgraph([loadingContext.loader.fetcher.urljoin(tool.tool["id"], "#" + r)
731-
for r in args.target],
732-
tool)
733-
else:
734-
_logger.error("Can only use --target on Workflows")
803+
ctool = choose_target(args, tool, loadingContext)
804+
if ctool is None:
735805
return 1
736-
if isinstance(loadingContext.loader.idx, CommentedMap):
737-
loadingContext.loader.idx[extracted["id"]] = extracted
738-
tool = make_tool(extracted["id"],
739-
loadingContext)
740806
else:
741-
raise Exception("Missing loadingContext.loader.idx!")
807+
tool = ctool
742808

743809
if args.print_subgraph:
744810
if "name" in tool.tool:
@@ -764,6 +830,7 @@ def my_represent_none(self, data): # pylint: disable=unused-argument
764830

765831
if isinstance(tool, int):
766832
return tool
833+
767834
# If on MacOS platform, TMPDIR must be set to be under one of the
768835
# shared volumes in Docker for Mac
769836
# More info: https://dockstore.org/docs/faq
@@ -774,18 +841,8 @@ def my_represent_none(self, data): # pylint: disable=unused-argument
774841
if runtimeContext.tmpdir_prefix == DEFAULT_TMP_PREFIX:
775842
runtimeContext.tmpdir_prefix = default_mac_path
776843

777-
for dirprefix in ("tmpdir_prefix", "tmp_outdir_prefix", "cachedir"):
778-
if getattr(runtimeContext, dirprefix) and getattr(runtimeContext, dirprefix) != DEFAULT_TMP_PREFIX:
779-
sl = "/" if getattr(runtimeContext, dirprefix).endswith("/") or dirprefix == "cachedir" \
780-
else ""
781-
setattr(runtimeContext, dirprefix,
782-
os.path.abspath(getattr(runtimeContext, dirprefix)) + sl)
783-
if not os.path.exists(os.path.dirname(getattr(runtimeContext, dirprefix))):
784-
try:
785-
os.makedirs(os.path.dirname(getattr(runtimeContext, dirprefix)))
786-
except Exception as e:
787-
_logger.error("Failed to create directory: %s", Text(e))
788-
return 1
844+
if check_working_directories(runtimeContext) is not None:
845+
return 1
789846

790847
if args.cachedir:
791848
if args.move_outputs == "move":

0 commit comments

Comments
 (0)