@@ -40,7 +40,7 @@ def __init__(self, config, stage = 'train'):
4040 def get_stage_dataset_from_config (self , stage ):
4141 assert (stage in ['train' , 'valid' , 'test' ])
4242 root_dir = self .config ['dataset' ]['root_dir' ]
43- modal_num = self .config ['dataset' ][ 'modal_num' ]
43+ modal_num = self .config ['dataset' ]. get ( 'modal_num' , 1 )
4444
4545 transform_key = stage + '_transform'
4646 if (stage == "valid" and transform_key not in self .config ['dataset' ]):
@@ -61,7 +61,7 @@ def get_stage_dataset_from_config(self, stage):
6161 data_transform = transforms .Compose (self .transform_list )
6262
6363 csv_file = self .config ['dataset' ].get (stage + '_csv' , None )
64- dataset = NiftyDataset (root_dir = root_dir ,
64+ dataset = NiftyDataset (root_dir = root_dir ,
6565 csv_file = csv_file ,
6666 modal_num = modal_num ,
6767 with_label = not (stage == 'test' ),
@@ -286,7 +286,7 @@ def train_valid(self):
286286 self .device = torch .device ("cuda:{0:}" .format (device_ids [0 ]))
287287 self .net .to (self .device )
288288 ckpt_dir = self .config ['training' ]['ckpt_save_dir' ]
289- ckpt_prefx = ckpt_dir .split ('/' )[- 1 ]
289+ ckpt_prefx = ckpt_dir .split ('/' )[- 1 ]
290290 iter_start = self .config ['training' ]['iter_start' ]
291291 iter_max = self .config ['training' ]['iter_max' ]
292292 iter_valid = self .config ['training' ]['iter_valid' ]
@@ -397,7 +397,7 @@ def test_time_dropout(m):
397397 infer_obj = Inferer (self .net , infer_cfg )
398398 infer_time_list = []
399399 with torch .no_grad ():
400- for data in self .test_loder :
400+ for data in self .test_loader :
401401 images = self .convert_tensor_type (data ['image' ])
402402 images = images .to (device )
403403
@@ -444,7 +444,7 @@ def infer_with_multiple_checkpoints(self):
444444 infer_obj = Inferer (self .net , infer_cfg )
445445 infer_time_list = []
446446 with torch .no_grad ():
447- for data in self .test_loder :
447+ for data in self .test_loader :
448448 images = self .convert_tensor_type (data ['image' ])
449449 images = images .to (device )
450450
0 commit comments