@@ -221,21 +221,22 @@ def training(self):
221221
222222 def validation (self ):
223223 class_num = self .config ['network' ]['class_num' ]
224- infer_cfg = self .config ['testing' ]
225- infer_cfg ['class_num' ] = class_num
224+ if (self .inferer is None ):
225+ infer_cfg = self .config ['testing' ]
226+ infer_cfg ['class_num' ] = class_num
227+ self .inferer = Inferer (infer_cfg )
226228
227229 valid_loss_list = []
228230 valid_dice_list = []
229231 validIter = iter (self .valid_loader )
230232 with torch .no_grad ():
231233 self .net .eval ()
232- infer_obj = Inferer (self .net , infer_cfg )
233234 for data in validIter :
234235 inputs = self .convert_tensor_type (data ['image' ])
235236 labels_prob = self .convert_tensor_type (data ['label_prob' ])
236237 inputs , labels_prob = inputs .to (self .device ), labels_prob .to (self .device )
237238 batch_n = inputs .shape [0 ]
238- outputs = infer_obj . run (inputs )
239+ outputs = self . inferer . run (self . net , inputs )
239240
240241 # The tensors are on CPU when calculating loss for validation data
241242 loss = self .get_loss_value (data , outputs , labels_prob )
@@ -286,6 +287,8 @@ def train_valid(self):
286287 self .device = torch .device ("cuda:{0:}" .format (device_ids [0 ]))
287288 self .net .to (self .device )
288289 ckpt_dir = self .config ['training' ]['ckpt_save_dir' ]
290+ if (ckpt_dir [- 1 ] == "/" ):
291+ ckpt_dir = ckpt_dir [:- 1 ]
289292 ckpt_prefx = ckpt_dir .split ('/' )[- 1 ]
290293 iter_start = self .config ['training' ]['iter_start' ]
291294 iter_max = self .config ['training' ]['iter_max' ]
@@ -392,9 +395,10 @@ def test_time_dropout(m):
392395 checkpoint = torch .load (ckpt_name , map_location = device )
393396 self .net .load_state_dict (checkpoint ['model_state_dict' ])
394397
395- infer_cfg = self .config ['testing' ]
396- infer_cfg ['class_num' ] = self .config ['network' ]['class_num' ]
397- infer_obj = Inferer (self .net , infer_cfg )
398+ if (self .inferer is None ):
399+ infer_cfg = self .config ['testing' ]
400+ infer_cfg ['class_num' ] = self .config ['network' ]['class_num' ]
401+ self .inferer = Inferer (infer_cfg )
398402 infer_time_list = []
399403 with torch .no_grad ():
400404 for data in self .test_loader :
@@ -412,7 +416,7 @@ def test_time_dropout(m):
412416 # continue
413417 start_time = time .time ()
414418
415- pred = infer_obj . run (images )
419+ pred = self . inferer . run (self . net , images )
416420 # convert tensor to numpy
417421 if (isinstance (pred , (tuple , list ))):
418422 pred = [item .cpu ().numpy () for item in pred ]
@@ -438,10 +442,11 @@ def infer_with_multiple_checkpoints(self):
438442 device_ids = self .config ['testing' ]['gpus' ]
439443 device = torch .device ("cuda:{0:}" .format (device_ids [0 ]))
440444
445+ if (self .inferer is None ):
446+ infer_cfg = self .config ['testing' ]
447+ infer_cfg ['class_num' ] = self .config ['network' ]['class_num' ]
448+ self .inferer = Inferer (infer_cfg )
441449 ckpt_names = self .config ['testing' ]['ckpt_name' ]
442- infer_cfg = self .config ['testing' ]
443- infer_cfg ['class_num' ] = self .config ['network' ]['class_num' ]
444- infer_obj = Inferer (self .net , infer_cfg )
445450 infer_time_list = []
446451 with torch .no_grad ():
447452 for data in self .test_loader :
@@ -463,7 +468,7 @@ def infer_with_multiple_checkpoints(self):
463468 checkpoint = torch .load (ckpt_name , map_location = device )
464469 self .net .load_state_dict (checkpoint ['model_state_dict' ])
465470
466- pred = infer_obj . run (images )
471+ pred = self . inferer . run (self . net , images )
467472 # convert tensor to numpy
468473 if (isinstance (pred , (tuple , list ))):
469474 pred = [item .cpu ().numpy () for item in pred ]
0 commit comments