@@ -86,26 +86,32 @@ def add_parser_arguments(parser) -> None:
8686 parser .add_argument ("--gpu_ids" , nargs = "+" , default = [0 ], type = int , help = "gpu_id(s)" )
8787
8888
89- def main (args : Optional [argparse .Namespace ] = None ) -> None :
89+ def main (args : Optional [argparse .Namespace ] = None ):
9090 """Trains a model."""
9191 time_start = time .time ()
92+
9293 if args is None :
9394 parser = argparse .ArgumentParser ()
9495 add_parser_arguments (parser )
9596 args = parser .parse_args ()
96- if args .json and not args .json .exists ():
97- save_default_train_options (args .json )
97+
98+ args .path_json = Path (args .json )
99+
100+ if args .path_json and not args .path_json .exists ():
101+ save_default_train_options (args .path_json )
98102 return
99- with open (args .json , "r" ) as fi :
103+
104+ with open (args .path_json , "r" ) as fi :
100105 train_options = json .load (fi )
106+
101107 args .__dict__ .update (train_options )
102108 add_logging_file_handler (Path (args .path_save_dir , "train_model.log" ))
103109 logger .info (f"Started training at: { datetime .datetime .now ()} " )
104110
105111 set_seeds (args .seed )
106112 log_training_options (vars (args ))
107113 path_model = os .path .join (args .path_save_dir , "model.p" )
108- model = fnet .models .load_or_init_model (path_model , args .json )
114+ model = fnet .models .load_or_init_model (path_model , args .path_json )
109115 init_cuda (args .gpu_ids [0 ])
110116 model .to_gpu (args .gpu_ids )
111117 logger .info (model )
@@ -124,6 +130,8 @@ def main(args: Optional[argparse.Namespace] = None) -> None:
124130 # Get patch pair providers
125131 bpds_train = get_bpds_train (args )
126132 bpds_val = get_bpds_val (args )
133+
134+ # MAIN LOOP
127135 for idx_iter in range (model .count_iter , args .n_iter ):
128136 do_save = ((idx_iter + 1 ) % args .interval_save == 0 ) or (
129137 (idx_iter + 1 ) == args .n_iter
@@ -164,6 +172,8 @@ def main(args: Optional[argparse.Namespace] = None) -> None:
164172 path_save = os .path .join (args .path_save_dir , "loss_curves.png" ),
165173 )
166174
175+ return model
176+
167177
168178def train_model (
169179 batch_size : int = 28 ,
@@ -182,8 +192,9 @@ def train_model(
182192 seed : Optional [int ] = None ,
183193 json : Optional [str ] = None ,
184194 gpu_ids : Optional [List [int ]] = None ,
185- ) -> None :
195+ ):
186196 """Python API for training."""
197+
187198 bpds_kwargs = bpds_kwargs or {
188199 "buffer_size" : 16 ,
189200 "buffer_switch_interval" : 2800 , # every 100 updates
@@ -201,7 +212,8 @@ def train_model(
201212 }
202213 iter_checkpoint = iter_checkpoint or []
203214 gpu_ids = gpu_ids or [0 ]
204- json = json or str (Path (path_save_dir , "train_options.json" ))
215+
216+ json = json or f"{ path_save_dir } train_options.json"
205217
206218 pnames , _ , _ , locs = inspect .getargvalues (inspect .currentframe ())
207219 train_options = {k : locs [k ] for k in pnames }
@@ -214,10 +226,11 @@ def train_model(
214226 path_json .parent .mkdir (parents = True )
215227
216228 json = globals ()["json" ] # retrieve global module
217- with path_json .open ("w" ) as fo :
218- json .dump (train_options , fo , indent = 4 , sort_keys = True )
229+ with path_json .open ("w" ) as f :
230+ json .dump (train_options , f , indent = 4 , sort_keys = True )
219231 logger .info (f"Saved: { path_json } " )
220232
221233 args = argparse .Namespace ()
222234 args .__dict__ .update (train_options )
223- main (args )
235+
236+ return main (args )
0 commit comments