Skip to content

Commit 0a71230

Browse files
committed
update NetRunAgent
1 parent 4ef1335 commit 0a71230

File tree

12 files changed

+540
-414
lines changed

12 files changed

+540
-414
lines changed

examples/JSRT2/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
In this example, we show how to use a customized CNN and a customized loss function to segment the heart from X-Ray images. The configurations are the same as those in the `JSRT` example except the network structure and loss function.
77

8-
The customized CNN is detailed in `my_net2d.py`, which is a modification of the 2D UNet. In this new network, we use a residual connection in each block. The customized loss is detailed in `my_loss.py`, where we combine Dice loss and MAE loss as our new loss function.
8+
The customized CNN is detailed in `my_net2d.py`, which is a modification of the 2D UNet. In this new network, we use a residual connection in each block. The customized loss is detailed in `my_loss.py`, where we define a focal dice loss.
99

10-
We also write a customized main function in `jsrt_train_infer.py` so that we can combine TrainInferAgent from PyMIC with our customized CNN and loss function.
10+
We also write a customized main function in `jsrt_net_run.py` so that we can combine NetRunAgent from PyMIC with our customized CNN and loss function.
1111

1212
## Data and preprocessing
1313
1. Data preprocessing is the same as that in the the `JSRT` example. Please follow that example for details.

examples/JSRT2/config/train_test.cfg

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ bilinear = True
5353
device_name = cuda:0
5454

5555
batch_size = 4
56-
loss_function = my_loss
56+
loss_type = MyFocalDiceLoss
57+
MyFocalDiceLoss_Enable_Pixel_Weight = False
58+
MyFocalDiceLoss_Enable_Class_Weight = True
59+
MyFocalDiceLoss_beta = 1.5
60+
class_weight = [0.2, 1.0]
5761

5862
# for optimizers
5963
optimizer = Adam

examples/JSRT2/jsrt_net_run.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from __future__ import print_function, division
33

44
import sys
5-
from pymic.net_run.net_run import TrainInferAgent
6-
from pymic.net_run.net_factory import net_dict
75
from pymic.util.parse_config import parse_config
6+
from pymic.net_run.net_run_agent import NetRunAgent
7+
from pymic.net.net_factory import net_dict
8+
from pymic.loss.loss_factory import loss_dict
89
from my_net2d import MyUNet2D
9-
from my_loss import MySegmentationLossCalculator
10+
from my_loss import MyFocalDiceLoss
1011

1112
my_net_dict = {
1213
"MyUNet2D": MyUNet2D
@@ -22,6 +23,20 @@ def get_network(params):
2223
raise ValueError("Undefined network: {0:}".format(net_type))
2324
return net
2425

26+
my_loss_dict = {
27+
"MyFocalDiceLoss": MyFocalDiceLoss
28+
}
29+
30+
def get_loss(params):
31+
loss_type = params["loss_type"]
32+
if(loss_type in my_loss_dict):
33+
loss_obj = my_loss_dict[loss_type](params)
34+
elif(loss_type in net_dict):
35+
loss_obj = loss_dict[loss_type](params)
36+
else:
37+
raise ValueError("Undefined loss: {0:}".format(loss_type))
38+
return loss_obj
39+
2540
def main():
2641
if(len(sys.argv) < 3):
2742
print('Number of arguments should be 3. e.g.')
@@ -33,10 +48,10 @@ def main():
3348

3449
# use custormized CNN and loss function
3550
net = get_network(config['network'])
36-
loss_cal = MySegmentationLossCalculator(config['training'])
37-
agent = TrainInferAgent(config, stage)
51+
loss_obj = get_loss(config['training'])
52+
agent = NetRunAgent(config, stage)
3853
agent.set_network(net)
39-
agent.set_loss_calculater(loss_cal)
54+
agent.set_loss_calculater(loss_obj)
4055
agent.run()
4156

4257
if __name__ == "__main__":

examples/JSRT2/my_loss.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
import numpy as np
7+
from pymic.loss.util import reshape_tensor_to_2D, get_classwise_dice
8+
9+
class MyFocalDiceLoss(nn.Module):
10+
"""
11+
Focal Dice loss proposed in the following paper:
12+
P. Wang et al. Focal dice loss and image dilatin for brain tumor segmentation.
13+
in Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical
14+
Decision Support, 2018.
15+
"""
16+
def __init__(self, params):
17+
super(MyFocalDiceLoss, self).__init__()
18+
self.enable_pix_weight = params['MyFocalDiceLoss_Enable_Pixel_Weight'.lower()]
19+
self.enable_cls_weight = params['MyFocalDiceLoss_Enable_Class_Weight'.lower()]
20+
self.beta = params['MyFocalDiceLoss_beta'.lower()]
21+
assert(self.beta >= 1.0)
22+
23+
def forward(self, loss_input_dict):
24+
predict = loss_input_dict['prediction']
25+
soft_y = loss_input_dict['ground_truth']
26+
pix_w = loss_input_dict['pixel_weight']
27+
cls_w = loss_input_dict['class_weight']
28+
softmax = loss_input_dict['softmax']
29+
30+
if(softmax):
31+
predict = nn.Softmax(dim = 1)(predict)
32+
predict = reshape_tensor_to_2D(predict)
33+
soft_y = reshape_tensor_to_2D(soft_y)
34+
35+
if(self.enable_pix_weight):
36+
if(pix_w is None):
37+
raise ValueError("Pixel weight is enabled but not defined")
38+
pix_w = reshape_tensor_to_2D(pix_w)
39+
dice_score = get_classwise_dice(predict, soft_y, pix_w)
40+
dice_score = 0.01 + dice_score * 0.98
41+
dice_loss = 1.0 - torch.pow(dice_score, 1.0 / self.beta)
42+
43+
if(self.enable_cls_weight):
44+
if(cls_w is None):
45+
raise ValueError("Class weight is enabled but not defined")
46+
weighted_loss = dice_loss * cls_w
47+
avg_loss = weighted_loss.sum() / cls_w.sum()
48+
else:
49+
avg_loss = torch.mean(dice_loss)
50+
return avg_loss

examples/JSRT2/my_net2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def forward(self, x):
107107
new_shape = [N, D] + list(output.shape)[1:]
108108
output = torch.reshape(output, new_shape)
109109
output = torch.transpose(output, 1, 2)
110+
110111
return output
111112

112113
if __name__ == "__main__":
@@ -122,7 +123,7 @@ def forward(self, x):
122123
xt = torch.from_numpy(x)
123124
xt = torch.tensor(xt)
124125

125-
y = Net(xt)
126+
y, y2 = Net(xt)
126127
print(len(y.size()))
127128
y = y.detach().numpy()
128129
print(y.shape)

pymic/loss/exp_log.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
# -*- coding: utf-8 -*-
3+
from __future__ import print_function, division
4+
5+
import torch
6+
import torch.nn as nn
7+
from pymic.loss.util import reshape_tensor_to_2D, get_classwise_dice
8+
9+
class ExpLogLoss(nn.Module):
10+
"""
11+
The exponential logarithmic loss in this paper:
12+
K. Wong et al.: 3D Segmentation with Exponential Logarithmic Loss for Highly
13+
Unbalanced Object Sizes. MICCAI 2018.
14+
"""
15+
def __init__(self, params):
16+
super(ExpLogLoss, self).__init__()
17+
self.w_dice = params['ExpLogLoss_w_dice'.lower()]
18+
self.gamma = params['ExpLogLoss_gamma'.lower()]
19+
20+
def forward(self, loss_input_dict):
21+
predict = loss_input_dict['prediction']
22+
soft_y = loss_input_dict['ground_truth']
23+
softmax = loss_input_dict['softmax']
24+
25+
if(softmax):
26+
predict = nn.Softmax(dim = 1)(predict)
27+
predict = reshape_tensor_to_2D(predict)
28+
soft_y = reshape_tensor_to_2D(soft_y)
29+
30+
31+
dice_score = get_classwise_dice(predict, soft_y)
32+
dice_score = 0.01 + dice_score * 0.98
33+
exp_dice = -torch.log(dice_score)
34+
exp_dice = torch.pow(exp_dice, self.gamma)
35+
exp_dice = torch.mean(exp_dice)
36+
37+
predict= 0.01 + predict * 0.98
38+
wc = torch.mean(soft_y, dim = 0)
39+
wc = 1.0 / (wc + 0.1)
40+
wc = torch.pow(wc, 0.5)
41+
ce = - torch.log(predict)
42+
exp_ce = wc * torch.pow(ce, self.gamma)
43+
exp_ce = torch.sum(soft_y * exp_ce, dim = 1)
44+
exp_ce = torch.mean(exp_ce)
45+
46+
loss = exp_dice * self.w_dice + exp_ce * (1.0 - self.w_dice)
47+
return loss

pymic/net/__init__.py

Whitespace-only changes.

pymic/net/net2d/__init__.py

Whitespace-only changes.

pymic/net/net2d/unet2d_scse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
import torch.nn as nn
1212
import numpy as np
13-
from pymic.net2d.squeeze_and_excitation import *
13+
from pymic.net.net2d.squeeze_and_excitation import *
1414

1515
class ConvScSEBlock(nn.Module):
1616
"""two convolution layers with batch norm and leaky relu"""

pymic/net/net3d/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)