@@ -455,7 +455,7 @@ def model_output(
455455 inputs = inputs .to ("cpu" )
456456
457457 model_output = lambda inputs : post_process_transforms (
458- self .config .model_info .get_model ().get_output (model , inputs )
458+ self .config .model_info .get_model ().get_output (model , inputs ) # TODO(cyril) refactor those functions
459459 )
460460
461461 def model_output (inputs ):
@@ -870,7 +870,7 @@ def inference(self):
870870 model .to ("cpu" )
871871
872872 except Exception as e :
873- self .log (f"Error : { e } " )
873+ self .log (f"Error during inference : { e } " )
874874 self .quit ()
875875 finally :
876876 self .quit ()
@@ -1078,6 +1078,7 @@ def train(self):
10781078 torch .set_num_threads (1 )
10791079 self .log ("Number of threads has been set to 1 for macOS" )
10801080
1081+ self .log (f"config model : { self .config .model_info .name } " )
10811082 model_name = model_config .name
10821083 model_class = model_config .get_model ()
10831084
@@ -1314,7 +1315,7 @@ def train(self):
13141315 )
13151316 )
13161317 except RuntimeError as e :
1317- logger .error (f"Error : { e } " )
1318+ logger .error (f"Error when loading weights : { e } " )
13181319 warn = (
13191320 "WARNING:\n It'd seem that the weights were incompatible with the model,\n "
13201321 "the model will be trained from random weights"
@@ -1333,6 +1334,9 @@ def train(self):
13331334
13341335 device = self .config .device
13351336
1337+ if model_name == "test" :
1338+ self .quit ()
1339+
13361340 for epoch in range (self .config .max_epochs ):
13371341 # self.log("\n")
13381342 self .log ("-" * 10 )
@@ -1472,7 +1476,7 @@ def train(self):
14721476 model .to ("cpu" )
14731477
14741478 except Exception as e :
1475- self .log (f"Error : { e } " )
1479+ self .log (f"Error in training : { e } " )
14761480 self .quit ()
14771481 finally :
14781482 self .quit ()
0 commit comments