Skip to content

Commit bd38bc9

Browse files
committed
support mixup
1 parent c541525 commit bd38bc9

File tree

4 files changed

+71
-17
lines changed

4 files changed

+71
-17
lines changed

pymic/loss/cls/basic.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@ def forward(self, loss_input_dict):
6565
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
6666
softmax = nn.Softmax(dim = 1)
6767
predict = softmax(predict)
68-
num_class = list(predict.size())[1]
69-
data_type = 'float' if(predict.dtype is torch.float32) else 'double'
70-
soft_y = get_soft_label(labels, num_class, data_type)
71-
loss = self.l1_loss(predict, soft_y)
68+
loss = self.l1_loss(predict, labels)
7269
return loss
7370

7471
class MSELoss(AbstractClassificationLoss):
@@ -84,10 +81,7 @@ def forward(self, loss_input_dict):
8481
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
8582
softmax = nn.Softmax(dim = 1)
8683
predict = softmax(predict)
87-
num_class = list(predict.size())[1]
88-
data_type = 'float' if(predict.dtype is torch.float32) else 'double'
89-
soft_y = get_soft_label(labels, num_class, data_type)
90-
loss = self.mse_loss(predict, soft_y)
84+
loss = self.mse_loss(predict, labels)
9185
return loss
9286

9387
class NLLLoss(AbstractClassificationLoss):

pymic/net_run/agent_cls.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111
from datetime import datetime
12+
from random import random
1213
from torch.optim import lr_scheduler
1314
from torchvision import transforms
1415
from tensorboardX import SummaryWriter
@@ -17,6 +18,7 @@
1718
from pymic.net.net_dict_cls import TorchClsNetDict
1819
from pymic.transform.trans_dict import TransformDict
1920
from pymic.net_run.agent_abstract import NetRunAgent
21+
from pymic.util.general import mixup
2022
import warnings
2123
warnings.filterwarnings('ignore', '.*output shape of zoom.*')
2224

@@ -111,16 +113,17 @@ def get_evaluation_score(self, outputs, labels):
111113
"""
112114
Get evaluation score for a prediction.
113115
114-
:param outputs: (tensor) Prediction obtained by a network.
115-
:param labels: (tensor) The ground truth.
116+
:param outputs: (tensor) Prediction obtained by a network with size N X C.
117+
:param labels: (tensor) The ground truth with size N X C.
116118
"""
117119
metrics = self.config['training'].get("evaluation_metric", "accuracy")
118120
if(metrics != "accuracy"): # default classification accuracy
119121
raise ValueError("Not implemeted for metric {0:}".format(metrics))
120122
if(self.task_type == "cls"):
121-
_, preds = torch.max(outputs, 1)
122-
consis= self.convert_tensor_type(preds == labels.data)
123-
score = torch.mean(consis)
123+
out_argmax = torch.argmax(outputs, 1)
124+
lab_argmax = torch.argmax(labels, 1)
125+
consis = self.convert_tensor_type(out_argmax == lab_argmax)
126+
score = torch.mean(consis)
124127
elif(self.task_type == "cls_nexcl"): #nonexclusive classification
125128
preds = self.convert_tensor_type(outputs > 0.5)
126129
consis= self.convert_tensor_type(preds == labels.data)
@@ -129,6 +132,7 @@ def get_evaluation_score(self, outputs, labels):
129132

130133
def training(self):
131134
iter_valid = self.config['training']['iter_valid']
135+
mixup_prob = self.config['training'].get('mixup_probability', 0.5)
132136
sample_num = 0
133137
running_loss = 0
134138
running_score= 0
@@ -140,8 +144,11 @@ def training(self):
140144
self.trainIter = iter(self.train_loader)
141145
data = next(self.trainIter)
142146
inputs = self.convert_tensor_type(data['image'])
143-
labels = data['label'].long()
147+
labels = self.convert_tensor_type(data['label_prob'])
148+
if(random() < mixup_prob):
149+
inputs, labels = mixup(inputs, labels)
144150
inputs, labels = inputs.to(self.device), labels.to(self.device)
151+
145152
# zero the parameter gradients
146153
self.optimizer.zero_grad()
147154
# forward + backward + optimize
@@ -174,7 +181,7 @@ def validation(self):
174181
self.net.eval()
175182
for data in validIter:
176183
inputs = self.convert_tensor_type(data['image'])
177-
labels = data['label'].long()
184+
labels = self.convert_tensor_type(data['label_prob'])
178185
inputs, labels = inputs.to(self.device), labels.to(self.device)
179186
self.optimizer.zero_grad()
180187
# forward + backward + optimize

pymic/net_run/agent_seg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch.optim as optim
1313
import torch.nn.functional as F
1414
from datetime import datetime
15+
from random import random
1516
from torch.optim import lr_scheduler
1617
from tensorboardX import SummaryWriter
1718
from pymic.io.image_read_write import save_nd_array_as_image
@@ -28,6 +29,7 @@
2829
from pymic.transform.trans_dict import TransformDict
2930
from pymic.util.post_process import PostProcessDict
3031
from pymic.util.image_process import convert_label
32+
from pymic.util.general import mixup
3133

3234
class SegmentationAgent(NetRunAgent):
3335
def __init__(self, config, stage = 'train'):
@@ -120,6 +122,7 @@ def set_postprocessor(self, postprocessor):
120122
def training(self):
121123
class_num = self.config['network']['class_num']
122124
iter_valid = self.config['training']['iter_valid']
125+
mixup_prob = self.config['training'].get('mixup_probability', 0.5)
123126
train_loss = 0
124127
train_dice_list = []
125128
self.net.train()
@@ -132,7 +135,9 @@ def training(self):
132135
# get the inputs
133136
inputs = self.convert_tensor_type(data['image'])
134137
labels_prob = self.convert_tensor_type(data['label_prob'])
135-
138+
if(random() < mixup_prob):
139+
inputs, labels_prob = mixup(inputs, labels_prob)
140+
136141
# # for debug
137142
# for i in range(inputs.shape[0]):
138143
# image_i = inputs[i][0]

pymic/util/general.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,52 @@ def get_one_hot_seg(label, class_num):
2929
one_hot = one_hot.view(*size)
3030
one_hot = torch.transpose(one_hot, 1, -1)
3131
one_hot = torch.squeeze(one_hot, -1)
32-
return one_hot
32+
return one_hot
33+
34+
def mixup(inputs, labels):
35+
"""Shuffle a minibatch and do linear interpolation between images and labels.
36+
Both classification and segmentation labels are supported. The targets should
37+
be one-hot labels.
38+
39+
:param inputs: a tensor of input images with size N X C0 x H x W.
40+
:param labels: a tensor of one-hot labels. The shape is N X C for classification
41+
tasks, and N X C X H X W for segmentation tasks.
42+
"""
43+
input_shape = list(inputs.shape)
44+
label_shape = list(labels.shape)
45+
img_dim = len(input_shape) - 2
46+
N = input_shape[0] # batch size
47+
C = label_shape[1] # class number
48+
rp1 = torch.randperm(N)
49+
inputs1 = inputs[rp1]
50+
labels1 = labels[rp1]
51+
52+
rp2 = torch.randperm(N)
53+
inputs2 = inputs[rp2]
54+
labels2 = labels[rp2]
55+
56+
a = np.random.beta(1, 1, [N, 1])
57+
if(img_dim == 2):
58+
b = np.tile(a[..., None, None], [1] + input_shape[1:])
59+
elif(img_dim == 3):
60+
b = np.tile(a[..., None, None, None], [1] + input_shape[1:])
61+
else:
62+
raise ValueError("MixUp only supports 2D and 3D images, but the " +
63+
"input image has {0:} dimensions".format(img_dim))
64+
65+
inputs1 = inputs1 * torch.from_numpy(b).float()
66+
inputs2 = inputs2 * torch.from_numpy(1 - b).float()
67+
inputs_mix = inputs1 + inputs2
68+
69+
if(len(label_shape) == 2): # for classification tasks
70+
c = np.tile(a, [1, C])
71+
elif(img_dim == 2): # for 2D segmentation tasks
72+
c = np.tile(a[..., None, None], [1] + label_shape[1:])
73+
else: # for 3D segmentation tasks
74+
c = np.tile(a[..., None, None, None], [1] + label_shape[1:])
75+
76+
labels1 = labels1 * torch.from_numpy(c).float()
77+
labels2 = labels2 * torch.from_numpy(1 - c).float()
78+
labels_mix = labels1 + labels2
79+
80+
return inputs_mix, labels_mix

0 commit comments

Comments
 (0)