1717from tensorboardX import SummaryWriter
1818from pymic .io .image_read_write import save_nd_array_as_image
1919from pymic .io .nifty_dataset import NiftyDataset
20- from pymic .io . transform3d import get_transform
21- from pymic .net .net_factory import get_network
20+ from pymic .transform . trans_dict import TransformDict
21+ from pymic .net .net_dict import NetDict
2222from pymic .net_run .infer_func import volume_infer
2323from pymic .net_run .get_optimizer import get_optimiser
24- from pymic .loss .loss_factory import get_loss
24+ from pymic .loss .loss_dict import LossDict
2525from pymic .loss .util import get_soft_label
2626from pymic .loss .util import reshape_prediction_and_ground_truth
2727from pymic .loss .util import get_classwise_dice
@@ -35,23 +35,29 @@ def __init__(self, config, stage = 'train'):
3535 self .stage = stage
3636 if (stage == 'inference' ):
3737 self .stage = 'test'
38- self .net = None
3938 self .train_set = None
4039 self .valid_set = None
4140 self .test_set = None
41+ self .net = None
4242 self .loss_calculater = None
43+ self .transform_dict = TransformDict
44+ self .loss_dict = LossDict
45+ self .net_dict = NetDict
4346 self .tensor_type = config ['dataset' ]['tensor_type' ]
4447
4548 def set_datasets (self , train_set , valid_set , test_set ):
4649 self .train_set = train_set
4750 self .valid_set = valid_set
4851 self .test_set = test_set
4952
50- def set_network (self , net ):
51- self .net = net
53+ def set_transform_dict (self , custom_transform_dict ):
54+ self .transform_dict = custom_transform_dict
5255
53- def set_loss_calculater (self , loss_calculater ):
54- self .loss_calculater = loss_calculater
56+ def set_network_dict (self , custom_net_dict ):
57+ self .net_dict = custom_net_dict
58+
59+ def set_loss_dict (self , custom_loss_dict ):
60+ self .loss_dict = custom_loss_dict
5561
5662 def get_stage_dataset_from_config (self , stage ):
5763 assert (stage in ['train' , 'valid' , 'test' ])
@@ -66,11 +72,14 @@ def get_stage_dataset_from_config(self, stage):
6672 with_weight = False
6773 else :
6874 raise ValueError ("Incorrect value for stage: {0:}" .format (stage ))
75+ self .transform_list = []
76+ for name in transform_names :
77+ if (name not in self .transform_dict ):
78+ raise (ValueError ("Undefined transform {0:}" .format (name )))
79+ one_transform = self .transform_dict [name ](self .config ['dataset' ])
80+ self .transform_list .append (one_transform )
6981
70- self .transform_list = [get_transform (name , self .config ['dataset' ]) \
71- for name in transform_names ]
7282 csv_file = self .config ['dataset' ].get (stage + '_csv' , None )
73-
7483 dataset = NiftyDataset (root_dir = root_dir ,
7584 csv_file = csv_file ,
7685 modal_num = modal_num ,
@@ -99,8 +108,10 @@ def create_dataset(self):
99108 batch_size = batch_size , shuffle = False , num_workers = batch_size )
100109
101110 def create_network (self ):
102- if (self .net is None ):
103- self .net = get_network (self .config ['network' ])
111+ net_name = self .config ['network' ]['net_type' ]
112+ if (net_name not in self .net_dict ):
113+ raise ValueError ("Undefined network {0:}" .format (net_name ))
114+ self .net = self .net_dict [net_name ](self .config ['network' ])
104115 if (self .tensor_type == 'float' ):
105116 self .net .float ()
106117 else :
@@ -156,11 +167,14 @@ def train(self):
156167 self .checkpoint = None
157168 self .create_optimizer ()
158169
159- train_loss = 0
170+ loss_name = self .config ['training' ]['loss_type' ]
171+ if (loss_name not in self .loss_dict ):
172+ raise ValueError ("Undefined loss function {0:}" .format (loss_name ))
173+ self .loss_calculater = self .loss_dict [loss_name ](self .config ['training' ])
174+
175+ trainIter = iter (self .train_loader )
176+ train_loss = 0
160177 train_dice_list = []
161- if (self .loss_calculater is None ):
162- self .loss_calculater = get_loss (self .config ['training' ])
163- trainIter = iter (self .train_loader )
164178 print ("{0:} training start" .format (str (datetime .now ())[:- 7 ]))
165179 for it in range (iter_start , iter_max ):
166180 try :
0 commit comments