1010
1111import argparse
1212import json
13+ import logging
1314import os
15+ import re
1416import sys
1517from collections .abc import MutableMapping
16- from typing import IO , TYPE_CHECKING , Any , Union , cast
18+ from io import TextIOWrapper
19+ from pathlib import Path
20+ from typing import (
21+ IO ,
22+ Any ,
23+ TextIO ,
24+ Union ,
25+ cast ,
26+ )
1727
1828from cwlformat .formatter import stringify_dict
19- from ruamel .yaml .dumper import RoundTripDumper
20- from ruamel .yaml .main import YAML , dump
29+ from ruamel .yaml .main import YAML
2130from ruamel .yaml .representer import RoundTripRepresenter
2231from schema_salad .sourceline import SourceLine , add_lc_filename
2332
24- if TYPE_CHECKING :
25- from _typeshed import StrPath
33+ from cwl_utils .loghandler import _logger as _cwlutilslogger
34+
# Dedicated logger for this tool: INFO and above go through its own stream handler.
_logger = logging.getLogger("cwl-graph-split")  # pylint: disable=invalid-name
defaultStreamHandler = logging.StreamHandler()  # pylint: disable=invalid-name
_logger.addHandler(defaultStreamHandler)
_logger.setLevel(logging.INFO)
# 100 is above logging.CRITICAL (50), so every standard-level message emitted
# through the shared cwl_utils logger is suppressed.
_cwlutilslogger.setLevel(100)
2640
2741
2842def arg_parser () -> argparse .ArgumentParser :
@@ -73,7 +87,7 @@ def run(args: list[str]) -> int:
7387 with open (options .cwlfile ) as source_handle :
7488 graph_split (
7589 source_handle ,
76- options .outdir ,
90+ Path ( options .outdir ) ,
7791 options .output_format ,
7892 options .mainfile ,
7993 options .pretty ,
@@ -83,7 +97,7 @@ def run(args: list[str]) -> int:
8397
8498def graph_split (
8599 sourceIO : IO [str ],
86- output_dir : "StrPath" ,
100+ output_dir : Path ,
87101 output_format : str ,
88102 mainfile : str ,
89103 pretty : bool ,
@@ -100,6 +114,13 @@ def graph_split(
100114
101115 version = source .pop ("cwlVersion" )
102116
117+ # Check outdir parent exists
118+ if not output_dir .parent .is_dir ():
119+ raise NotADirectoryError (f"Parent directory of { output_dir } does not exist" )
120+ # If output_dir is not a directory, create it
121+ if not output_dir .is_dir ():
122+ output_dir .mkdir ()
123+
103124 def my_represent_none (
104125 self : Any , data : Any
105126 ) -> Any : # pylint: disable=unused-argument
@@ -111,7 +132,7 @@ def my_represent_none(
111132 for entry in source ["$graph" ]:
112133 entry_id = entry .pop ("id" ).lstrip ("#" )
113134 entry ["cwlVersion" ] = version
114- imports = rewrite (entry , entry_id )
135+ imports = rewrite (entry , entry_id , output_dir )
115136 if imports :
116137 for import_name in imports :
117138 rewrite_types (entry , f"#{ import_name } " , False )
@@ -121,25 +142,27 @@ def my_represent_none(
121142 else :
122143 entry_id = mainfile
123144
124- output_file = os . path . join ( output_dir , entry_id + ".cwl" )
145+ output_file = output_dir / ( re . sub ( ".cwl$" , "" , entry_id ) + ".cwl" )
125146 if output_format == "json" :
126147 json_dump (entry , output_file )
127148 elif output_format == "yaml" :
128149 yaml_dump (entry , output_file , pretty )
129150
130151
131- def rewrite (document : Any , doc_id : str ) -> set [str ]:
152+ def rewrite (
153+ document : Any , doc_id : str , output_dir : Path , pretty : bool = False
154+ ) -> set [str ]:
132155 """Rewrite the given element from the CWL $graph."""
133156 imports = set ()
134157 if isinstance (document , list ) and not isinstance (document , str ):
135158 for entry in document :
136- imports .update (rewrite (entry , doc_id ))
159+ imports .update (rewrite (entry , doc_id , output_dir , pretty ))
137160 elif isinstance (document , dict ):
138161 this_id = document ["id" ] if "id" in document else None
139162 for key , value in document .items ():
140163 with SourceLine (document , key , Exception ):
141164 if key == "run" and isinstance (value , str ) and value [0 ] == "#" :
142- document [key ] = f"{ value [1 :]} .cwl"
165+ document [key ] = f"{ re . sub ( '.cwl$' , '' , value [1 :]) } .cwl"
143166 elif key in ("id" , "outputSource" ) and value .startswith ("#" + doc_id ):
144167 document [key ] = value [len (doc_id ) + 2 :]
145168 elif key == "out" and isinstance (value , list ):
@@ -179,15 +202,15 @@ def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
179202 elif key == "$import" :
180203 rewrite_import (document )
181204 elif key == "class" and value == "SchemaDefRequirement" :
182- return rewrite_schemadef (document )
205+ return rewrite_schemadef (document , output_dir , pretty )
183206 else :
184- imports .update (rewrite (value , doc_id ))
207+ imports .update (rewrite (value , doc_id , output_dir , pretty ))
185208 return imports
186209
187210
def rewrite_import(document: MutableMapping[str, Any]) -> None:
    """Adjust the $import directive.

    The ``$import`` value is reduced, in place, to the part before the first
    ``/`` with any leading ``#`` characters removed.

    :param document: mapping holding a string ``$import`` entry; it is mutated.
    """
    # lstrip (rather than slicing off one character) leaves a reference that
    # has no leading "#" untouched.
    external_file = document["$import"].split("/")[0].lstrip("#")
    document["$import"] = external_file
192215
193216
@@ -215,19 +238,21 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
215238 rewrite_types (entry , entry_file , sameself )
216239
217240
218- def rewrite_schemadef (document : MutableMapping [str , Any ]) -> set [str ]:
241+ def rewrite_schemadef (
242+ document : MutableMapping [str , Any ], output_dir : Path , pretty : bool = False
243+ ) -> set [str ]:
219244 """Dump the schemadefs to their own file."""
220245 for entry in document ["types" ]:
221246 if "$import" in entry :
222247 rewrite_import (entry )
223248 elif "name" in entry and "/" in entry ["name" ]:
224- entry_file , entry ["name" ] = entry ["name" ].split ("/" )
249+ entry_file , entry ["name" ] = entry ["name" ].lstrip ( "#" ). split ("/" )
225250 for field in entry ["fields" ]:
226251 field ["name" ] = field ["name" ].split ("/" )[2 ]
227252 rewrite_types (field , entry_file , True )
228- with open ( entry_file [ 1 :], "a" , encoding = "utf-8" ) as entry_handle :
229- dump ([ entry ] , entry_handle , Dumper = RoundTripDumper )
230- entry ["$import" ] = entry_file [ 1 :]
253+ with ( output_dir / entry_file ). open ( "a" , encoding = "utf-8" ) as entry_handle :
254+ yaml_dump ( entry , entry_handle , pretty )
255+ entry ["$import" ] = entry_file
231256 del entry ["name" ]
232257 del entry ["type" ]
233258 del entry ["fields" ]
@@ -247,26 +272,40 @@ def seen_import(entry: MutableMapping[str, Any]) -> bool:
247272 return seen_imports
248273
249274
def json_dump(entry: Any, output_file: Union[str, Path]) -> None:
    """Output object as JSON.

    :param entry: JSON-serialisable object to write.
    :param output_file: destination path; an existing file is overwritten.
        Both ``str`` and ``Path`` are accepted so callers that still pass a
        plain string keep working.
    """
    with Path(output_file).open("w", encoding="utf-8") as result_handle:
        json.dump(entry, result_handle, indent=4)
254279
255280
def yaml_dump(
    entry: Any,
    output_file_or_handle: Union[str, Path, TextIOWrapper, TextIO],
    pretty: bool,
) -> None:
    """Output object as YAML.

    :param entry: the object to serialise.
    :param output_file_or_handle: either a path (``str`` or ``Path``), which
        is opened for writing and overwritten, or an already-open text handle
        (anything with a ``write`` method), which is written to and left open.
    :param pretty: when true, write the text returned by ``stringify_dict``
        instead of the ruamel.yaml round-trip dump.
    :raises ValueError: if the target is neither a path nor a writable handle.
    """
    if isinstance(output_file_or_handle, (str, Path)):
        # Open the path ourselves, then reuse the handle branch below.
        with open(output_file_or_handle, "w", encoding="utf-8") as result_handle:
            yaml_dump(entry, result_handle, pretty)
        return

    # Duck-type the handle: an isinstance check against TextIOWrapper/TextIO
    # rejects valid text streams such as io.StringIO.
    if not hasattr(output_file_or_handle, "write"):
        raise ValueError(
            f"output_file_or_handle must be a string or a file handle but got {type(output_file_or_handle)}"
        )

    if pretty:
        output_file_or_handle.write(stringify_dict(entry))
        return

    yaml = YAML(typ="rt", pure=True)
    yaml.default_flow_style = False
    yaml.indent = 4
    yaml.block_seq_indent = 2
    yaml.dump(entry, output_file_or_handle)
270309
271310
272311if __name__ == "__main__" :
0 commit comments