Skip to content

Commit 1f2d34d

Browse files
committed
update net dict and loss dictg
1 parent 0a71230 commit 1f2d34d

File tree

7 files changed

+262
-382
lines changed

7 files changed

+262
-382
lines changed

examples/JSRT2/jsrt_net_run.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,15 @@
44
import sys
55
from pymic.util.parse_config import parse_config
66
from pymic.net_run.net_run_agent import NetRunAgent
7-
from pymic.net.net_factory import net_dict
8-
from pymic.loss.loss_factory import loss_dict
7+
from pymic.net.net_dict import NetDict
8+
from pymic.loss.loss_dict import LossDict
99
from my_net2d import MyUNet2D
1010
from my_loss import MyFocalDiceLoss
1111

12-
my_net_dict = {
13-
"MyUNet2D": MyUNet2D
14-
}
15-
16-
def get_network(params):
17-
net_type = params["net_type"]
18-
if(net_type in my_net_dict):
19-
net = my_net_dict[net_type](params)
20-
elif(net_type in net_dict):
21-
net = net_dict[net_type](params)
22-
else:
23-
raise ValueError("Undefined network: {0:}".format(net_type))
24-
return net
25-
26-
my_loss_dict = {
27-
"MyFocalDiceLoss": MyFocalDiceLoss
28-
}
29-
30-
def get_loss(params):
31-
loss_type = params["loss_type"]
32-
if(loss_type in my_loss_dict):
33-
loss_obj = my_loss_dict[loss_type](params)
34-
elif(loss_type in net_dict):
35-
loss_obj = loss_dict[loss_type](params)
36-
else:
37-
raise ValueError("Undefined loss: {0:}".format(loss_type))
38-
return loss_obj
12+
my_net_dict = NetDict
13+
my_net_dict["MyUNet2D"] = MyUNet2D
14+
my_loss_dict = LossDict
15+
my_loss_dict["MyFocalDiceLoss"] = MyFocalDiceLoss
3916

4017
def main():
4118
if(len(sys.argv) < 3):
@@ -47,11 +24,9 @@ def main():
4724
config = parse_config(cfg_file)
4825

4926
# use custormized CNN and loss function
50-
net = get_network(config['network'])
51-
loss_obj = get_loss(config['training'])
5227
agent = NetRunAgent(config, stage)
53-
agent.set_network(net)
54-
agent.set_loss_calculater(loss_obj)
28+
agent.set_network_dict(my_net_dict)
29+
agent.set_loss_dict(my_loss_dict)
5530
agent.run()
5631

5732
if __name__ == "__main__":
Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,10 @@
55
from pymic.loss.dice import DiceWithCrossEntropyLoss, NoiseRobustDiceLoss
66
from pymic.loss.exp_log import ExpLogLoss
77

8-
loss_dict = {'CrossEntropyLoss': CrossEntropyLoss,
8+
LossDict = {'CrossEntropyLoss': CrossEntropyLoss,
99
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
1010
'DiceLoss': DiceLoss,
1111
'MultiScaleDiceLoss': MultiScaleDiceLoss,
1212
'DiceWithCrossEntropyLoss': DiceWithCrossEntropyLoss,
1313
'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
1414
'ExpLogLoss': ExpLogLoss}
15-
16-
def get_loss(params):
17-
loss_type = params['loss_type']
18-
if(loss_type in loss_dict):
19-
loss_obj = loss_dict[loss_type](params)
20-
else:
21-
raise ValueError("Undefined loss type {0:}".format(loss_type))
22-
return loss_obj
Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,10 @@
66
from pymic.net.net3d.unet2d5 import UNet2D5
77
from pymic.net.net3d.unet3d import UNet3D
88

9-
net_dict = {
9+
NetDict = {
1010
'UNet2D': UNet2D,
1111
'COPLENet': COPLENet,
1212
'UNet2D_ScSE': UNet2D_ScSE,
1313
'UNet2D5': UNet2D5,
1414
'UNet3D': UNet3D
15-
}
16-
17-
def get_network(params):
18-
net_type = params['net_type']
19-
if(net_type in net_dict):
20-
net_obj = net_dict[net_type](params)
21-
else:
22-
raise ValueError("Undefined network type {0:}".format(net_type))
23-
return net_obj
15+
}

pymic/net_run/net_run_agent.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from tensorboardX import SummaryWriter
1818
from pymic.io.image_read_write import save_nd_array_as_image
1919
from pymic.io.nifty_dataset import NiftyDataset
20-
from pymic.io.transform3d import get_transform
21-
from pymic.net.net_factory import get_network
20+
from pymic.transform.trans_dict import TransformDict
21+
from pymic.net.net_dict import NetDict
2222
from pymic.net_run.infer_func import volume_infer
2323
from pymic.net_run.get_optimizer import get_optimiser
24-
from pymic.loss.loss_factory import get_loss
24+
from pymic.loss.loss_dict import LossDict
2525
from pymic.loss.util import get_soft_label
2626
from pymic.loss.util import reshape_prediction_and_ground_truth
2727
from pymic.loss.util import get_classwise_dice
@@ -35,23 +35,29 @@ def __init__(self, config, stage = 'train'):
3535
self.stage = stage
3636
if(stage == 'inference'):
3737
self.stage = 'test'
38-
self.net = None
3938
self.train_set = None
4039
self.valid_set = None
4140
self.test_set = None
41+
self.net = None
4242
self.loss_calculater = None
43+
self.transform_dict = TransformDict
44+
self.loss_dict = LossDict
45+
self.net_dict = NetDict
4346
self.tensor_type = config['dataset']['tensor_type']
4447

4548
def set_datasets(self, train_set, valid_set, test_set):
4649
self.train_set = train_set
4750
self.valid_set = valid_set
4851
self.test_set = test_set
4952

50-
def set_network(self, net):
51-
self.net = net
53+
def set_transform_dict(self, custom_transform_dict):
54+
self.transform_dict = custom_transform_dict
5255

53-
def set_loss_calculater(self, loss_calculater):
54-
self.loss_calculater = loss_calculater
56+
def set_network_dict(self, custom_net_dict):
57+
self.net_dict = custom_net_dict
58+
59+
def set_loss_dict(self, custom_loss_dict):
60+
self.loss_dict = custom_loss_dict
5561

5662
def get_stage_dataset_from_config(self, stage):
5763
assert(stage in ['train', 'valid', 'test'])
@@ -66,11 +72,14 @@ def get_stage_dataset_from_config(self, stage):
6672
with_weight = False
6773
else:
6874
raise ValueError("Incorrect value for stage: {0:}".format(stage))
75+
self.transform_list = []
76+
for name in transform_names:
77+
if(name not in self.transform_dict):
78+
raise(ValueError("Undefined transform {0:}".format(name)))
79+
one_transform = self.transform_dict[name](self.config['dataset'])
80+
self.transform_list.append(one_transform)
6981

70-
self.transform_list = [get_transform(name, self.config['dataset']) \
71-
for name in transform_names ]
7282
csv_file = self.config['dataset'].get(stage + '_csv', None)
73-
7483
dataset = NiftyDataset(root_dir=root_dir,
7584
csv_file = csv_file,
7685
modal_num = modal_num,
@@ -99,8 +108,10 @@ def create_dataset(self):
99108
batch_size=batch_size, shuffle=False, num_workers=batch_size)
100109

101110
def create_network(self):
102-
if(self.net is None):
103-
self.net = get_network(self.config['network'])
111+
net_name = self.config['network']['net_type']
112+
if(net_name not in self.net_dict):
113+
raise ValueError("Undefined network {0:}".format(net_name))
114+
self.net = self.net_dict[net_name](self.config['network'])
104115
if(self.tensor_type == 'float'):
105116
self.net.float()
106117
else:
@@ -156,11 +167,14 @@ def train(self):
156167
self.checkpoint = None
157168
self.create_optimizer()
158169

159-
train_loss = 0
170+
loss_name = self.config['training']['loss_type']
171+
if(loss_name not in self.loss_dict):
172+
raise ValueError("Undefined loss function {0:}".format(loss_name))
173+
self.loss_calculater = self.loss_dict[loss_name](self.config['training'])
174+
175+
trainIter = iter(self.train_loader)
176+
train_loss = 0
160177
train_dice_list = []
161-
if(self.loss_calculater is None):
162-
self.loss_calculater = get_loss(self.config['training'])
163-
trainIter = iter(self.train_loader)
164178
print("{0:} training start".format(str(datetime.now())[:-7]))
165179
for it in range(iter_start, iter_max):
166180
try:

pymic/transform/__init__.py

Whitespace-only changes.

pymic/transform/trans_dict.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
from pymic.transform.transform3d import *
4+
5+
TransformDict = {
6+
'ChannelWiseGammaCorrection': ChannelWiseGammaCorrection,
7+
'ChannelWiseNormalize': ChannelWiseNormalize,
8+
'ChannelWiseThreshold': ChannelWiseThreshold,
9+
'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize,
10+
'CropWithBoundingBox': CropWithBoundingBox,
11+
'LabelConvert': LabelConvert,
12+
'LabelConvertNonzero': LabelConvertNonzero,
13+
'LabelToProbability': LabelToProbability,
14+
'RandomCrop': RandomCrop,
15+
'RandomFlip': RandomFlip,
16+
'RandomRotate': RandomRotate,
17+
'ReduceLabelDim': ReduceLabelDim,
18+
'Rescale': Rescale,
19+
'Pad': Pad,
20+
}

0 commit comments

Comments
 (0)