1212import json
1313import os
1414import sys
15- from collections .abc import MutableMapping
16- from typing import IO , TYPE_CHECKING , Any , Union , cast
15+ from io import TextIOWrapper
16+ from pathlib import Path
17+ from typing import (
18+ IO ,
19+ TYPE_CHECKING ,
20+ Any ,
21+ List ,
22+ MutableMapping ,
23+ Set ,
24+ Union ,
25+ cast ,
26+ Optional , TextIO ,
27+ )
28+ import logging
29+ import re
1730
1831from cwlformat .formatter import stringify_dict
19- from ruamel .yaml .dumper import RoundTripDumper
20- from ruamel .yaml .main import YAML , dump
32+ from ruamel .yaml .comments import Format
33+ from ruamel .yaml .main import YAML
2134from ruamel .yaml .representer import RoundTripRepresenter
2235from schema_salad .sourceline import SourceLine , add_lc_filename
2336
37+ from cwl_utils .loghandler import _logger as _cwlutilslogger
38+
2439if TYPE_CHECKING :
2540 from _typeshed import StrPath
2641
42+ _logger = logging .getLogger ("cwl-graph-split" ) # pylint: disable=invalid-name
43+ defaultStreamHandler = logging .StreamHandler () # pylint: disable=invalid-name
44+ _logger .addHandler (defaultStreamHandler )
45+ _logger .setLevel (logging .INFO )
46+ _cwlutilslogger .setLevel (100 )
47+
2748
2849def arg_parser () -> argparse .ArgumentParser :
2950 """Build the argument parser."""
@@ -82,11 +103,11 @@ def run(args: list[str]) -> int:
82103
83104
84105def graph_split (
85- sourceIO : IO [str ],
86- output_dir : "StrPath" ,
87- output_format : str ,
88- mainfile : str ,
89- pretty : bool ,
106+ sourceIO : IO [str ],
107+ output_dir : "StrPath" ,
108+ output_format : str ,
109+ mainfile : str ,
110+ pretty : bool ,
90111) -> None :
91112 """Loop over the provided packed CWL document and split it up."""
92113 yaml = YAML (typ = "rt" )
@@ -100,8 +121,15 @@ def graph_split(
100121
101122 version = source .pop ("cwlVersion" )
102123
124+ # Check outdir parent exists
125+ if not Path (output_dir ).parent .is_dir ():
126+ raise NotADirectoryError (f"Parent directory of { output_dir } does not exist" )
127+ # If output_dir is not a directory, create it
128+ if not Path (output_dir ).is_dir ():
129+ os .mkdir (output_dir )
130+
103131 def my_represent_none (
104- self : Any , data : Any
132+ self : Any , data : Any
105133 ) -> Any : # pylint: disable=unused-argument
106134 """Force clean representation of 'null'."""
107135 return self .represent_scalar ("tag:yaml.org,2002:null" , "null" )
@@ -111,7 +139,7 @@ def my_represent_none(
111139 for entry in source ["$graph" ]:
112140 entry_id = entry .pop ("id" ).lstrip ("#" )
113141 entry ["cwlVersion" ] = version
114- imports = rewrite (entry , entry_id )
142+ imports = rewrite (entry , entry_id , Path ( output_dir ) )
115143 if imports :
116144 for import_name in imports :
117145 rewrite_types (entry , f"#{ import_name } " , False )
@@ -121,47 +149,47 @@ def my_represent_none(
121149 else :
122150 entry_id = mainfile
123151
124- output_file = os . path . join (output_dir , entry_id + ".cwl" )
152+ output_file = Path (output_dir ) / ( re . sub ( ".cwl$" , "" , entry_id ) + ".cwl" )
125153 if output_format == "json" :
126154 json_dump (entry , output_file )
127155 elif output_format == "yaml" :
128156 yaml_dump (entry , output_file , pretty )
129157
130158
131- def rewrite (document : Any , doc_id : str ) -> set [str ]:
159+ def rewrite (document : Any , doc_id : str , output_dir : Path , pretty : Optional [ bool ] = False ) -> Set [str ]:
132160 """Rewrite the given element from the CWL $graph."""
133161 imports = set ()
134162 if isinstance (document , list ) and not isinstance (document , str ):
135163 for entry in document :
136- imports .update (rewrite (entry , doc_id ))
164+ imports .update (rewrite (entry , doc_id , output_dir , pretty ))
137165 elif isinstance (document , dict ):
138166 this_id = document ["id" ] if "id" in document else None
139167 for key , value in document .items ():
140168 with SourceLine (document , key , Exception ):
141169 if key == "run" and isinstance (value , str ) and value [0 ] == "#" :
142- document [key ] = f"{ value [1 :]} .cwl"
170+ document [key ] = f"{ re . sub ( '.cwl$' , '' , value [1 :]) } .cwl"
143171 elif key in ("id" , "outputSource" ) and value .startswith ("#" + doc_id ):
144- document [key ] = value [len (doc_id ) + 2 :]
172+ document [key ] = value [len (doc_id ) + 2 :]
145173 elif key == "out" and isinstance (value , list ):
146174
147175 def rewrite_id (entry : Any ) -> Union [MutableMapping [Any , Any ], str ]:
148176 if isinstance (entry , MutableMapping ):
149177 if entry ["id" ].startswith (this_id ):
150178 assert isinstance (this_id , str ) # nosec B101
151- entry ["id" ] = cast (str , entry ["id" ])[len (this_id ) + 1 :]
179+ entry ["id" ] = cast (str , entry ["id" ])[len (this_id ) + 1 :]
152180 return entry
153181 elif isinstance (entry , str ):
154182 if this_id and entry .startswith (this_id ):
155- return entry [len (this_id ) + 1 :]
183+ return entry [len (this_id ) + 1 :]
156184 return entry
157185 raise Exception (f"{ entry } is neither a dictionary nor string." )
158186
159187 document [key ][:] = [rewrite_id (entry ) for entry in value ]
160188 elif key in ("source" , "scatter" , "items" , "format" ):
161189 if (
162- isinstance (value , str )
163- and value .startswith ("#" )
164- and "/" in value
190+ isinstance (value , str )
191+ and value .startswith ("#" )
192+ and "/" in value
165193 ):
166194 referrant_file , sub = value [1 :].split ("/" , 1 )
167195 if referrant_file == doc_id :
@@ -172,22 +200,22 @@ def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
172200 new_sources = list ()
173201 for entry in value :
174202 if entry .startswith ("#" + doc_id ):
175- new_sources .append (entry [len (doc_id ) + 2 :])
203+ new_sources .append (entry [len (doc_id ) + 2 :])
176204 else :
177205 new_sources .append (entry )
178206 document [key ] = new_sources
179207 elif key == "$import" :
180208 rewrite_import (document )
181209 elif key == "class" and value == "SchemaDefRequirement" :
182- return rewrite_schemadef (document )
210+ return rewrite_schemadef (document , output_dir , pretty )
183211 else :
184- imports .update (rewrite (value , doc_id ))
212+ imports .update (rewrite (value , doc_id , output_dir , pretty ))
185213 return imports
186214
187215
188216def rewrite_import (document : MutableMapping [str , Any ]) -> None :
189217 """Adjust the $import directive."""
190- external_file = document ["$import" ].split ("/" )[0 ][ 1 :]
218+ external_file = document ["$import" ].split ("/" )[0 ]. lstrip ( "#" )
191219 document ["$import" ] = external_file
192220
193221
@@ -203,7 +231,7 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
203231 if key == name :
204232 if isinstance (value , str ) and value .startswith (entry_file ):
205233 if sameself :
206- field [key ] = value [len (entry_file ) + 1 :]
234+ field [key ] = value [len (entry_file ) + 1 :]
207235 else :
208236 field [key ] = "{d[0]}#{d[1]}" .format (
209237 d = value [1 :].split ("/" , 1 )
@@ -215,19 +243,19 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
215243 rewrite_types (entry , entry_file , sameself )
216244
217245
218- def rewrite_schemadef (document : MutableMapping [str , Any ]) -> set [str ]:
246+ def rewrite_schemadef (document : MutableMapping [str , Any ], output_dir : Path , pretty : Optional [ bool ] = False ) -> Set [str ]:
219247 """Dump the schemadefs to their own file."""
220248 for entry in document ["types" ]:
221249 if "$import" in entry :
222250 rewrite_import (entry )
223251 elif "name" in entry and "/" in entry ["name" ]:
224- entry_file , entry ["name" ] = entry ["name" ].split ("/" )
252+ entry_file , entry ["name" ] = entry ["name" ].lstrip ( "#" ). split ("/" )
225253 for field in entry ["fields" ]:
226254 field ["name" ] = field ["name" ].split ("/" )[2 ]
227255 rewrite_types (field , entry_file , True )
228- with open (entry_file [ 1 :] , "a" , encoding = "utf-8" ) as entry_handle :
229- dump ([ entry ] , entry_handle , Dumper = RoundTripDumper )
230- entry ["$import" ] = entry_file [ 1 :]
256+ with open (output_dir / entry_file , "a" , encoding = "utf-8" ) as entry_handle :
257+ yaml_dump ( entry , entry_handle , pretty )
258+ entry ["$import" ] = entry_file
231259 del entry ["name" ]
232260 del entry ["type" ]
233261 del entry ["fields" ]
@@ -253,20 +281,33 @@ def json_dump(entry: Any, output_file: str) -> None:
253281 json .dump (entry , result_handle , indent = 4 )
254282
255283
256- def yaml_dump (entry : Any , output_file : str , pretty : bool ) -> None :
284+ def yaml_dump (entry : Any , output_file_or_handle : Optional [ Union [ str , Path , TextIOWrapper , TextIO ]] , pretty : bool ) -> None :
257285 """Output object as YAML."""
258- yaml = YAML (typ = "rt" )
286+ yaml = YAML (typ = "rt" , pure = True )
259287 yaml .default_flow_style = False
260- yaml .map_indent = 4
261- yaml .sequence_indent = 2
262- with open (output_file , "w" , encoding = "utf-8" ) as result_handle :
288+ yaml .indent = 4
289+ yaml .block_seq_indent = 2
290+
291+ if isinstance (output_file_or_handle , (str , Path )):
292+ with open (output_file_or_handle , "w" , encoding = "utf-8" ) as result_handle :
293+ if pretty :
294+ result_handle .write (stringify_dict (entry ))
295+ else :
296+ yaml .dump (
297+ entry ,
298+ result_handle
299+ )
300+ elif isinstance (output_file_or_handle , (TextIOWrapper , TextIO )):
263301 if pretty :
264- result_handle .write (stringify_dict (entry ))
302+ output_file_or_handle .write (stringify_dict (entry ))
265303 else :
266304 yaml .dump (
267305 entry ,
268- result_handle ,
306+ output_file_or_handle
269307 )
308+ else :
309+ raise ValueError (
310+ f"output_file_or_handle must be a string or a file handle but got { type (output_file_or_handle )} " )
270311
271312
272313if __name__ == "__main__" :
0 commit comments