1- from typing import *
2-
31import math
4- from pathlib import Path
52import traceback
6- from tqdm import tqdm
3+ from pathlib import Path
4+ from typing import Any , Callable , Iterator , List , NoReturn , Optional , Union
75
86from seutil import IOUtils , LoggingUtils
7+ from tqdm import tqdm
8+
9+ logger = LoggingUtils .get_logger (__name__ )
910
1011
1112class FilesManager :
1213 """
1314 Handles the loading/dumping of files in a dataset.
1415 """
15- logger = LoggingUtils .get_logger (__name__ )
1616
1717 ALL_LEMMAS_BACKEND_SEXP_TRANSFORMATIONS = "all-lemmas-bsexp-transformations"
1818 ALL_LEMMAS_FOREEND_SEXP_TRANSFORMATIONS = "all-lemmas-fsexp-transformations"
@@ -34,7 +34,7 @@ def __init__(self, data_dir: Path):
def clean_path(self, rel_path: Union[str, List[str]]):
    """
    Remove whatever exists at `rel_path` (relative to data_dir), if anything.

    :param rel_path: relative path, either a string or a list of path segments.
    """
    abs_path = self.data_dir / self.assemble_rel_path(rel_path)
    if not abs_path.exists():
        # Nothing on disk at that location; nothing to do.
        return
    logger.info(f"Removing existing things at {abs_path}")
    IOUtils.rm(abs_path)
@@ -43,7 +43,8 @@ def clean_path(self, rel_path: Union[str, List[str]]):
def is_json_format(cls, fmt: IOUtils.Format) -> bool:
    """Return True iff `fmt` is one of the json-based serialization formats."""
    json_formats = (IOUtils.Format.json, IOUtils.Format.jsonPretty, IOUtils.Format.jsonNoSort)
    return fmt in json_formats
4545
46- def dump_data (self ,
46+ def dump_data (
47+ self ,
4748 rel_path : Union [str , List [str ]],
4849 data : Any ,
4950 fmt : IOUtils .Format ,
@@ -53,50 +54,43 @@ def dump_data(self,
5354 ):
5455 abs_path = self .data_dir / self .assemble_rel_path (rel_path )
5556 if abs_path .exists () and not exist_ok :
56- LoggingUtils .log_and_raise (self .logger , f"Cannot rewrite existing data at { abs_path } " , IOError )
57- # end if
57+ raise IOError (f"Cannot rewrite existing data at { abs_path } " )
5858
5959 abs_path .parent .mkdir (parents = True , exist_ok = True )
6060 if not is_batched :
6161 if self .is_json_format (fmt ):
6262 data = IOUtils .jsonfy (data )
63- # end if
6463 IOUtils .dump (abs_path , data , fmt )
6564 else :
6665 # In batched mode, the data need to be slice-able and sizable
6766 IOUtils .rm (abs_path )
6867 abs_path .mkdir (parents = True )
6968
70- for batch_i in tqdm (range (math .ceil (len (data )/ per_batch ))):
71- data_batch = data [per_batch * batch_i : per_batch * (batch_i + 1 )]
69+ for batch_i in tqdm (range (math .ceil (len (data ) / per_batch ))):
70+ data_batch = data [per_batch * batch_i : per_batch * (batch_i + 1 )]
7271 if self .is_json_format (fmt ):
7372 data_batch = IOUtils .jsonfy (data_batch )
74- # end if
75- IOUtils .dump (abs_path / f"batch-{ batch_i } .{ fmt .get_extension ()} " , data_batch , fmt )
76- # end for
77- # end if
73+ IOUtils .dump (abs_path / f"batch-{ batch_i } .{ fmt .get_extension ()} " , data_batch , fmt )
7874 return
7975
80- def load_data (self ,
76+ def load_data (
77+ self ,
8178 rel_path : Union [str , List [str ]],
8279 fmt : IOUtils .Format ,
8380 is_batched : bool = False ,
84- clz = None ,
81+ clz = None ,
8582 ) -> Any :
8683 if self .is_json_format (fmt ) and clz is None :
87- self .logger .warning (f"Load data from { rel_path } with json format, but did not specify clz (at { traceback .format_stack ()} )" )
88- # end if
84+ logger .warning (f"Load data from { rel_path } with json format, but did not specify clz (at { traceback .format_stack ()} )" )
8985
9086 abs_path = self .data_dir / self .assemble_rel_path (rel_path )
9187 if not abs_path .exists ():
92- LoggingUtils .log_and_raise (self .logger , f"Cannot find data at { abs_path } " , IOError )
93- # end if
88+ raise IOError (f"Cannot find data at { abs_path } " )
9489
9590 if not is_batched :
9691 data = IOUtils .load (abs_path , fmt )
9792 if self .is_json_format (fmt ) and clz is not None :
9893 data = IOUtils .dejsonfy (data , clz )
99- # end if
10094 return data
10195 else :
10296 data = list ()
@@ -106,25 +100,21 @@ def load_data(self,
106100 data_batch = IOUtils .load (batch_file , fmt )
107101 if self .is_json_format (fmt ) and clz is not None :
108102 data_batch = IOUtils .dejsonfy (data_batch , clz )
109- # end if
110103 data .extend (data_batch )
111- # end for
112104 return data
113- # end if
114105
115- def iter_batched_data (self ,
106+ def iter_batched_data (
107+ self ,
116108 rel_path : Union [str , List [str ]],
117109 fmt : IOUtils .Format ,
118- clz = None ,
110+ clz = None ,
119111 ) -> Iterator :
120112 if self .is_json_format (fmt ) and clz is None :
121- self .logger .warning (f"Load data from { rel_path } with json format, but did not specify clz" )
122- # end if
113+ logger .warning (f"Load data from { rel_path } with json format, but did not specify clz" )
123114
124115 abs_path = self .data_dir / self .assemble_rel_path (rel_path )
125116 if not abs_path .exists ():
126- LoggingUtils .log_and_raise (self .logger , f"Cannot find data at { abs_path } " , IOError )
127- # end if
117+ raise IOError (f"Cannot find data at { abs_path } " )
128118
129119 batch_numbers = sorted ([int (str (f .stem ).split ("-" )[1 ]) for f in abs_path .iterdir ()])
130120 for batch_number in batch_numbers :
@@ -134,10 +124,12 @@ def iter_batched_data(self,
134124 data_entry = IOUtils .dejsonfy (data_entry , clz )
135125 # end if
136126 yield data_entry
137- # end for
138- # end for
139127
140- def dump_ckpt (self , rel_path : Union [str , List [str ]], obj : Any , ckpt_id : int ,
128+ def dump_ckpt (
129+ self ,
130+ rel_path : Union [str , List [str ]],
131+ obj : Any ,
132+ ckpt_id : int ,
141133 dump_func : Callable [[Any , str ], NoReturn ],
142134 ckpt_keep_max : int = 5 ,
143135 ) -> NoReturn :
@@ -152,25 +144,23 @@ def dump_ckpt(self, rel_path: Union[str, List[str]], obj: Any, ckpt_id: int,
152144 ckpt_ids = [int (str (f .name )) for f in abs_path .iterdir ()]
153145 for ckpt_id in sorted (ckpt_ids )[:- ckpt_keep_max ]:
154146 IOUtils .rm (abs_path / str (ckpt_id ))
155- # end for
156- # end if
157147 return
158148
def load_ckpt(
    self,
    rel_path: Union[str, List[str]],
    load_func: Callable[[str], Any],
    ckpt_id: Optional[int] = None,
) -> Any:
    """
    Load a checkpoint stored under `rel_path` using the caller-supplied `load_func`.

    :param rel_path: relative path (string or list of segments) of the checkpoint directory.
    :param load_func: callback that receives the checkpoint path (as str) and returns the object.
    :param ckpt_id: which checkpoint to load; when None, the highest-numbered one is used.
    :raises IOError: if the checkpoint directory does not exist.
    """
    abs_path = self.data_dir / self.assemble_rel_path(rel_path)
    if not abs_path.exists():
        raise IOError(f"Cannot find data at {abs_path}")

    if ckpt_id is None:
        # Checkpoint directories are named by integer id; pick the largest.
        ckpt_id = max(int(str(f.name)) for f in abs_path.iterdir())
        logger.info(f"Loading the latest checkpoint {ckpt_id} at {abs_path}")

    return load_func(str(abs_path / str(ckpt_id)))
176166
@@ -181,5 +171,4 @@ def resolve(self, rel_path: Union[str, List[str]]) -> Path:
def assemble_rel_path(cls, rel_path: Union[str, List[str]]) -> str:
    """Normalize a relative path: a str passes through; a list of segments is joined with "/"."""
    if isinstance(rel_path, str):
        return rel_path
    return "/".join(rel_path)
0 commit comments