Skip to content

Commit 0218ebc

Browse files
committed
noise-learning
update pymic for noise-learning
1 parent 02ed81f commit 0218ebc

File tree

16 files changed

+542
-37
lines changed

16 files changed

+542
-37
lines changed

pymic/io/nifty_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2,
9595
csv_file, modal_num, with_label, transform)
9696
self.class_num = class_num
9797
print("class number for ClassificationDataset", self.class_num)
98+
print("self.transform", self.transform)
9899

99100
def __getlabel__(self, idx):
100101
csv_keys = list(self.csv_items.keys())
@@ -126,7 +127,7 @@ def __getitem__(self, idx):
126127
sample['label'] = self.__getlabel__(idx)
127128
if (self.image_weight_idx is not None):
128129
sample['image_weight'] = self.__getweight__(idx)
130+
print("***transform", self.transform)
129131
if self.transform:
130132
sample = self.transform(sample)
131133
return sample
132-

pymic/loss/loss_dict_seg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import torch.nn as nn
44
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
55
from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss
6+
from pymic.loss.seg.slsr import SLSRLoss
67
from pymic.loss.seg.exp_log import ExpLogLoss
78
from pymic.loss.seg.mse import MSELoss, MAELoss
89

910
SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss,
1011
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
12+
'SLSRLoss': SLSRLoss,
1113
'DiceLoss': DiceLoss,
1214
'FocalDiceLoss': FocalDiceLoss,
1315
'NoiseRobustDiceLoss': NoiseRobustDiceLoss,

pymic/loss/seg/gatedcrf_backup.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
import numpy as np
7+
from pymic.loss.seg.ce import CrossEntropyLoss
8+
from pymic.loss.seg.gatedcrf_util import ModelLossSemsegGatedCRF
9+
10+
class GatedCRFLoss(nn.Module):
11+
def __init__(self, params):
12+
super(GatedCRFLoss, self).__init__()
13+
self.gcrf_loss = ModelLossSemsegGatedCRF()
14+
self.softmax = params.get('loss_softmax', True)
15+
w0 = params['GatedCRFLoss_W0'.lower()]
16+
xy0= params['GatedCRFLoss_XY0'.lower()]
17+
rgb= params['GatedCRFLoss_rgb'.lower()]
18+
w1 = params['GatedCRFLoss_W1'.lower()]
19+
xy1= params['GatedCRFLoss_XY1'.lower()]
20+
kernel0 = {'weight': w0, 'xy': xy0, 'rgb': rgb}
21+
kernel1 = {'weight': w1, 'xy': xy1}
22+
self.kernels = [kernel0, kernel1]
23+
self.radius = params['GatedCRFLoss_Radius'.lower()]
24+
25+
def forward(self, loss_input_dict):
26+
predict = loss_input_dict['prediction']
27+
image = loss_input_dict['image'] # should be normalized by mean, std
28+
scribble= loss_input_dict['scribbles']
29+
validity_mask = loss_input_dict['validity_mask']
30+
31+
if(self.softmax):
32+
predict = nn.Softmax(dim = 1)(predict)
33+
34+
batch_dict = {'rgb': image,
35+
'semseg_scribbles': scribble}
36+
x_shape = list(predict.shape)
37+
l_crf = {'loss': 0}
38+
if(self.gcrf_w > 0):
39+
l_crf = self.gcrf_loss(predict,
40+
self.kernels,
41+
self.radius,
42+
batch_dict,
43+
x_shape[-2],
44+
x_shape[-1],
45+
mask_src=validity_mask,
46+
out_kernels_vis=False,
47+
)
48+
return l_crf['loss']

pymic/loss/seg/slsr.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Spatial Label Smoothing Regularization (SLSR) loss for learning from
4+
noisy annotatins according to the following paper:
5+
Minqing Zhang, Jiantao Gao et al., Characterizing Label Errors:
6+
Confident Learning for Noisy-Labeled Image Segmentation, MICCAI 2020.
7+
"""
8+
from __future__ import print_function, division
9+
10+
import torch
11+
import torch.nn as nn
12+
from pymic.loss.seg.util import reshape_tensor_to_2D
13+
14+
class SLSRLoss(nn.Module):
15+
def __init__(self, params):
16+
super(SLSRLoss, self).__init__()
17+
if(params is None):
18+
params = {}
19+
self.softmax = params.get('loss_softmax', True)
20+
self.epsilon = params.get('slsrloss_softmax', 0.25)
21+
22+
def forward(self, loss_input_dict):
23+
predict = loss_input_dict['prediction']
24+
soft_y = loss_input_dict['ground_truth']
25+
pix_w = loss_input_dict.get('pixel_weight', None)
26+
# the pixel wight here is actually the confidence mask
27+
# i.e., if the value is one, it means the label of corresponding
28+
# pixel is noisy and should be replaced with smoothed label.
29+
30+
if(isinstance(predict, (list, tuple))):
31+
predict = predict[0]
32+
if(self.softmax):
33+
predict = nn.Softmax(dim = 1)(predict)
34+
predict = reshape_tensor_to_2D(predict)
35+
soft_y = reshape_tensor_to_2D(soft_y)
36+
if(pix_w is not None):
37+
pix_w = reshape_tensor_to_2D(pix_w > 0).float()
38+
39+
# smooth labels for pixels in the unconfident mask
40+
smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5
41+
smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y
42+
else:
43+
smooth_y = soft_y
44+
45+
# for numeric stability
46+
predict = predict * 0.999 + 5e-4
47+
ce = - smooth_y* torch.log(predict)
48+
ce = torch.sum(ce, dim = 1) # shape is [N]
49+
ce = torch.mean(ce)
50+
return ce

pymic/net/cls/torch_pretrained_net.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@ class ResNet18(nn.Module):
2424
def __init__(self, params):
2525
super(ResNet18, self).__init__()
2626
self.params = params
27-
net_name = params['net_type']
2827
cls_num = params['class_num']
29-
in_chns = params['input_chns']
30-
self.pretrain = params['pretrain']
28+
in_chns = params.get('input_chns', 3)
29+
self.pretrain = params.get('pretrain', True)
3130
self.update_layers = params.get('update_layers', 0)
3231
self.net = models.resnet18(pretrained = self.pretrain)
3332

@@ -51,10 +50,9 @@ class VGG16(nn.Module):
5150
def __init__(self, params):
5251
super(VGG16, self).__init__()
5352
self.params = params
54-
net_name = params['net_type']
5553
cls_num = params['class_num']
56-
in_chns = params['input_chns']
57-
self.pretrain = params['pretrain']
54+
in_chns = params.get('input_chns', 3)
55+
self.pretrain = params.get('pretrain', True)
5856
self.update_layers = params.get('update_layers', 0)
5957
self.net = models.vgg16(pretrained = self.pretrain)
6058

@@ -78,10 +76,9 @@ class MobileNetV2(nn.Module):
7876
def __init__(self, params):
7977
super(MobileNetV2, self).__init__()
8078
self.params = params
81-
net_name = params['net_type']
8279
cls_num = params['class_num']
83-
in_chns = params['input_chns']
84-
self.pretrain = params['pretrain']
80+
in_chns = params.get('input_chns', 3)
81+
self.pretrain = params.get('pretrain', True)
8582
self.update_layers = params.get('update_layers', 0)
8683
self.net = models.mobilenet_v2(pretrained = self.pretrain)
8784

pymic/net/net3d/unet3d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(self, params):
9696
self.n_class = self.params['class_num']
9797
self.trilinear = self.params['trilinear']
9898
self.deep_sup = self.params['deep_supervise']
99+
self.stage = self.params['stage']
99100
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
100101

101102
self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
@@ -133,7 +134,7 @@ def forward(self, x):
133134
x_d1 = self.up3(x_d2, x1)
134135
x_d0 = self.up4(x_d1, x0)
135136
output = self.out_conv(x_d0)
136-
if(self.deep_sup):
137+
if(self.deep_sup and self.stage == "train"):
137138
out_shape = list(output.shape)[2:]
138139
output1 = self.out_conv1(x_d1)
139140
output1 = interpolate(output1, out_shape, mode = 'trilinear')

pymic/net_run/agent_abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def worker_init_fn(worker_id):
144144
bn_test = self.config['dataset'].get('test_batch_size', 1)
145145
if(self.test_set is None):
146146
self.test_set = self.get_stage_dataset_from_config('test')
147-
self.test_loder = torch.utils.data.DataLoader(self.test_set,
147+
self.test_loader = torch.utils.data.DataLoader(self.test_set,
148148
batch_size = bn_test, shuffle=False, num_workers= bn_test)
149149

150150
def create_optimizer(self, params):

pymic/net_run/agent_cls.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def train_valid(self):
207207
self.net.to(self.device)
208208

209209
ckpt_dir = self.config['training']['ckpt_save_dir']
210-
ckpt_prefx = self.config['training']['ckpt_save_prefix']
210+
ckpt_prefx = ckpt_dir.split('/')[-1]
211211
iter_start = self.config['training']['iter_start']
212212
iter_max = self.config['training']['iter_max']
213213
iter_valid = self.config['training']['iter_valid']
@@ -290,9 +290,12 @@ def infer(self):
290290
out_prob_list = []
291291
out_lab_list = []
292292
with torch.no_grad():
293-
for data in self.test_loder:
293+
for data in self.test_loader:
294294
names = data['names']
295+
if(names[0] != "data3_process/20190711_1005487059.png"):
296+
continue
295297
inputs = self.convert_tensor_type(data['image'])
298+
print("intensity mean and std", inputs.mean(), inputs.std())
296299
inputs = inputs.to(device)
297300

298301
start_time = time.time()

pymic/net_run/agent_seg.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, config, stage = 'train'):
4040
def get_stage_dataset_from_config(self, stage):
4141
assert(stage in ['train', 'valid', 'test'])
4242
root_dir = self.config['dataset']['root_dir']
43-
modal_num = self.config['dataset']['modal_num']
43+
modal_num = self.config['dataset'].get('modal_num', 1)
4444

4545
transform_key = stage + '_transform'
4646
if(stage == "valid" and transform_key not in self.config['dataset']):
@@ -61,7 +61,7 @@ def get_stage_dataset_from_config(self, stage):
6161
data_transform = transforms.Compose(self.transform_list)
6262

6363
csv_file = self.config['dataset'].get(stage + '_csv', None)
64-
dataset = NiftyDataset(root_dir=root_dir,
64+
dataset = NiftyDataset(root_dir = root_dir,
6565
csv_file = csv_file,
6666
modal_num = modal_num,
6767
with_label= not (stage == 'test'),
@@ -286,7 +286,7 @@ def train_valid(self):
286286
self.device = torch.device("cuda:{0:}".format(device_ids[0]))
287287
self.net.to(self.device)
288288
ckpt_dir = self.config['training']['ckpt_save_dir']
289-
ckpt_prefx = ckpt_dir.split('/')[-1]
289+
ckpt_prefx = ckpt_dir.split('/')[-1]
290290
iter_start = self.config['training']['iter_start']
291291
iter_max = self.config['training']['iter_max']
292292
iter_valid = self.config['training']['iter_valid']
@@ -397,7 +397,7 @@ def test_time_dropout(m):
397397
infer_obj = Inferer(self.net, infer_cfg)
398398
infer_time_list = []
399399
with torch.no_grad():
400-
for data in self.test_loder:
400+
for data in self.test_loader:
401401
images = self.convert_tensor_type(data['image'])
402402
images = images.to(device)
403403

@@ -444,7 +444,7 @@ def infer_with_multiple_checkpoints(self):
444444
infer_obj = Inferer(self.net, infer_cfg)
445445
infer_time_list = []
446446
with torch.no_grad():
447-
for data in self.test_loder:
447+
for data in self.test_loader:
448448
images = self.convert_tensor_type(data['image'])
449449
images = images.to(device)
450450

pymic/net_run/infer_func.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,22 @@ def run(self, image):
137137
outputs = self.__infer(image)
138138
elif(tta_mode == 1): # test time augmentation with flip in 2D
139139
outputs1 = self.__infer(image)
140-
outputs2 = self.__infer(torch.flip(image, [-1]))
141-
outputs3 = self.__infer(torch.flip(image, [-2]))
142-
outputs4 = self.__infer(torch.flip(image, [-2, -1]))
140+
outputs2 = self.__infer(torch.flip(image, [-2]))
141+
outputs3 = self.__infer(torch.flip(image, [-3]))
142+
outputs4 = self.__infer(torch.flip(image, [-2, -3]))
143143
if(isinstance(outputs1, (tuple, list))):
144144
outputs = []
145145
for i in range(len(outputs)):
146146
temp_out1 = outputs1[i]
147-
temp_out2 = torch.flip(outputs2[i], [-1])
148-
temp_out3 = torch.flip(outputs3[i], [-2])
149-
temp_out4 = torch.flip(outputs4[i], [-2, -1])
147+
temp_out2 = torch.flip(outputs2[i], [-2])
148+
temp_out3 = torch.flip(outputs3[i], [-3])
149+
temp_out4 = torch.flip(outputs4[i], [-2, -3])
150150
temp_mean = (temp_out1 + temp_out2 + temp_out3 + temp_out4) / 4
151151
outputs.append(temp_mean)
152152
else:
153-
outputs2 = torch.flip(outputs2, [-1])
154-
outputs3 = torch.flip(outputs3, [-2])
155-
outputs4 = torch.flip(outputs4, [-2, -1])
153+
outputs2 = torch.flip(outputs2, [-2])
154+
outputs3 = torch.flip(outputs3, [-3])
155+
outputs4 = torch.flip(outputs4, [-2, -3])
156156
outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4
157157
else:
158158
raise ValueError("Undefined tta_mode {0:}".format(tta_mode))

0 commit comments

Comments
 (0)