Skip to content

Commit 9638882

Browse files
alexiswlmr-c
authored andcommitted
Use yaml.dump over old dump command, stripped double .cwl, and fixed imports
Added cwlformat to requirements (needed for --pretty parameter),
1 parent c1875d5 commit 9638882

File tree

2 files changed

+80
-38
lines changed

2 files changed

+80
-38
lines changed

cwl_utils/graph_split.py

Lines changed: 79 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,39 @@
1212
import json
1313
import os
1414
import sys
15-
from collections.abc import MutableMapping
16-
from typing import IO, TYPE_CHECKING, Any, Union, cast
15+
from io import TextIOWrapper
16+
from pathlib import Path
17+
from typing import (
18+
IO,
19+
TYPE_CHECKING,
20+
Any,
21+
List,
22+
MutableMapping,
23+
Set,
24+
Union,
25+
cast,
26+
Optional, TextIO,
27+
)
28+
import logging
29+
import re
1730

1831
from cwlformat.formatter import stringify_dict
19-
from ruamel.yaml.dumper import RoundTripDumper
20-
from ruamel.yaml.main import YAML, dump
32+
from ruamel.yaml.comments import Format
33+
from ruamel.yaml.main import YAML
2134
from ruamel.yaml.representer import RoundTripRepresenter
2235
from schema_salad.sourceline import SourceLine, add_lc_filename
2336

37+
from cwl_utils.loghandler import _logger as _cwlutilslogger
38+
2439
if TYPE_CHECKING:
2540
from _typeshed import StrPath
2641

42+
_logger = logging.getLogger("cwl-graph-split") # pylint: disable=invalid-name
43+
defaultStreamHandler = logging.StreamHandler() # pylint: disable=invalid-name
44+
_logger.addHandler(defaultStreamHandler)
45+
_logger.setLevel(logging.INFO)
46+
_cwlutilslogger.setLevel(100)
47+
2748

2849
def arg_parser() -> argparse.ArgumentParser:
2950
"""Build the argument parser."""
@@ -82,11 +103,11 @@ def run(args: list[str]) -> int:
82103

83104

84105
def graph_split(
85-
sourceIO: IO[str],
86-
output_dir: "StrPath",
87-
output_format: str,
88-
mainfile: str,
89-
pretty: bool,
106+
sourceIO: IO[str],
107+
output_dir: "StrPath",
108+
output_format: str,
109+
mainfile: str,
110+
pretty: bool,
90111
) -> None:
91112
"""Loop over the provided packed CWL document and split it up."""
92113
yaml = YAML(typ="rt")
@@ -100,8 +121,15 @@ def graph_split(
100121

101122
version = source.pop("cwlVersion")
102123

124+
# Check outdir parent exists
125+
if not Path(output_dir).parent.is_dir():
126+
raise NotADirectoryError(f"Parent directory of {output_dir} does not exist")
127+
# If output_dir is not a directory, create it
128+
if not Path(output_dir).is_dir():
129+
os.mkdir(output_dir)
130+
103131
def my_represent_none(
104-
self: Any, data: Any
132+
self: Any, data: Any
105133
) -> Any: # pylint: disable=unused-argument
106134
"""Force clean representation of 'null'."""
107135
return self.represent_scalar("tag:yaml.org,2002:null", "null")
@@ -111,7 +139,7 @@ def my_represent_none(
111139
for entry in source["$graph"]:
112140
entry_id = entry.pop("id").lstrip("#")
113141
entry["cwlVersion"] = version
114-
imports = rewrite(entry, entry_id)
142+
imports = rewrite(entry, entry_id, Path(output_dir))
115143
if imports:
116144
for import_name in imports:
117145
rewrite_types(entry, f"#{import_name}", False)
@@ -121,47 +149,47 @@ def my_represent_none(
121149
else:
122150
entry_id = mainfile
123151

124-
output_file = os.path.join(output_dir, entry_id + ".cwl")
152+
output_file = Path(output_dir) / (re.sub(".cwl$", "", entry_id) + ".cwl")
125153
if output_format == "json":
126154
json_dump(entry, output_file)
127155
elif output_format == "yaml":
128156
yaml_dump(entry, output_file, pretty)
129157

130158

131-
def rewrite(document: Any, doc_id: str) -> set[str]:
159+
def rewrite(document: Any, doc_id: str, output_dir: Path, pretty: Optional[bool] = False) -> Set[str]:
132160
"""Rewrite the given element from the CWL $graph."""
133161
imports = set()
134162
if isinstance(document, list) and not isinstance(document, str):
135163
for entry in document:
136-
imports.update(rewrite(entry, doc_id))
164+
imports.update(rewrite(entry, doc_id, output_dir, pretty))
137165
elif isinstance(document, dict):
138166
this_id = document["id"] if "id" in document else None
139167
for key, value in document.items():
140168
with SourceLine(document, key, Exception):
141169
if key == "run" and isinstance(value, str) and value[0] == "#":
142-
document[key] = f"{value[1:]}.cwl"
170+
document[key] = f"{re.sub('.cwl$', '', value[1:])}.cwl"
143171
elif key in ("id", "outputSource") and value.startswith("#" + doc_id):
144-
document[key] = value[len(doc_id) + 2 :]
172+
document[key] = value[len(doc_id) + 2:]
145173
elif key == "out" and isinstance(value, list):
146174

147175
def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
148176
if isinstance(entry, MutableMapping):
149177
if entry["id"].startswith(this_id):
150178
assert isinstance(this_id, str) # nosec B101
151-
entry["id"] = cast(str, entry["id"])[len(this_id) + 1 :]
179+
entry["id"] = cast(str, entry["id"])[len(this_id) + 1:]
152180
return entry
153181
elif isinstance(entry, str):
154182
if this_id and entry.startswith(this_id):
155-
return entry[len(this_id) + 1 :]
183+
return entry[len(this_id) + 1:]
156184
return entry
157185
raise Exception(f"{entry} is neither a dictionary nor string.")
158186

159187
document[key][:] = [rewrite_id(entry) for entry in value]
160188
elif key in ("source", "scatter", "items", "format"):
161189
if (
162-
isinstance(value, str)
163-
and value.startswith("#")
164-
and "/" in value
190+
isinstance(value, str)
191+
and value.startswith("#")
192+
and "/" in value
165193
):
166194
referrant_file, sub = value[1:].split("/", 1)
167195
if referrant_file == doc_id:
@@ -172,22 +200,22 @@ def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
172200
new_sources = list()
173201
for entry in value:
174202
if entry.startswith("#" + doc_id):
175-
new_sources.append(entry[len(doc_id) + 2 :])
203+
new_sources.append(entry[len(doc_id) + 2:])
176204
else:
177205
new_sources.append(entry)
178206
document[key] = new_sources
179207
elif key == "$import":
180208
rewrite_import(document)
181209
elif key == "class" and value == "SchemaDefRequirement":
182-
return rewrite_schemadef(document)
210+
return rewrite_schemadef(document, output_dir, pretty)
183211
else:
184-
imports.update(rewrite(value, doc_id))
212+
imports.update(rewrite(value, doc_id, output_dir, pretty))
185213
return imports
186214

187215

188216
def rewrite_import(document: MutableMapping[str, Any]) -> None:
189217
"""Adjust the $import directive."""
190-
external_file = document["$import"].split("/")[0][1:]
218+
external_file = document["$import"].split("/")[0].lstrip("#")
191219
document["$import"] = external_file
192220

193221

@@ -203,7 +231,7 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
203231
if key == name:
204232
if isinstance(value, str) and value.startswith(entry_file):
205233
if sameself:
206-
field[key] = value[len(entry_file) + 1 :]
234+
field[key] = value[len(entry_file) + 1:]
207235
else:
208236
field[key] = "{d[0]}#{d[1]}".format(
209237
d=value[1:].split("/", 1)
@@ -215,19 +243,19 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
215243
rewrite_types(entry, entry_file, sameself)
216244

217245

218-
def rewrite_schemadef(document: MutableMapping[str, Any]) -> set[str]:
246+
def rewrite_schemadef(document: MutableMapping[str, Any], output_dir: Path, pretty: Optional[bool] = False) -> Set[str]:
219247
"""Dump the schemadefs to their own file."""
220248
for entry in document["types"]:
221249
if "$import" in entry:
222250
rewrite_import(entry)
223251
elif "name" in entry and "/" in entry["name"]:
224-
entry_file, entry["name"] = entry["name"].split("/")
252+
entry_file, entry["name"] = entry["name"].lstrip("#").split("/")
225253
for field in entry["fields"]:
226254
field["name"] = field["name"].split("/")[2]
227255
rewrite_types(field, entry_file, True)
228-
with open(entry_file[1:], "a", encoding="utf-8") as entry_handle:
229-
dump([entry], entry_handle, Dumper=RoundTripDumper)
230-
entry["$import"] = entry_file[1:]
256+
with open(output_dir / entry_file, "a", encoding="utf-8") as entry_handle:
257+
yaml_dump(entry, entry_handle, pretty)
258+
entry["$import"] = entry_file
231259
del entry["name"]
232260
del entry["type"]
233261
del entry["fields"]
@@ -253,20 +281,33 @@ def json_dump(entry: Any, output_file: str) -> None:
253281
json.dump(entry, result_handle, indent=4)
254282

255283

256-
def yaml_dump(entry: Any, output_file: str, pretty: bool) -> None:
284+
def yaml_dump(entry: Any, output_file_or_handle: Optional[Union[str, Path, TextIOWrapper, TextIO]], pretty: bool) -> None:
257285
"""Output object as YAML."""
258-
yaml = YAML(typ="rt")
286+
yaml = YAML(typ="rt", pure=True)
259287
yaml.default_flow_style = False
260-
yaml.map_indent = 4
261-
yaml.sequence_indent = 2
262-
with open(output_file, "w", encoding="utf-8") as result_handle:
288+
yaml.indent = 4
289+
yaml.block_seq_indent = 2
290+
291+
if isinstance(output_file_or_handle, (str, Path)):
292+
with open(output_file_or_handle, "w", encoding="utf-8") as result_handle:
293+
if pretty:
294+
result_handle.write(stringify_dict(entry))
295+
else:
296+
yaml.dump(
297+
entry,
298+
result_handle
299+
)
300+
elif isinstance(output_file_or_handle, (TextIOWrapper, TextIO)):
263301
if pretty:
264-
result_handle.write(stringify_dict(entry))
302+
output_file_or_handle.write(stringify_dict(entry))
265303
else:
266304
yaml.dump(
267305
entry,
268-
result_handle,
306+
output_file_or_handle
269307
)
308+
else:
309+
raise ValueError(
310+
f"output_file_or_handle must be a string or a file handle but got {type(output_file_or_handle)}")
270311

271312

272313
if __name__ == "__main__":

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ requests
55
schema-salad >= 8.8.20250205075315,<9
66
ruamel.yaml >= 0.17.6, < 0.19
77
typing_extensions;python_version<'3.10'
8+
cwlformat >= 2022.2.18

0 commit comments

Comments
 (0)