Skip to content

Commit 02ed81f

Browse files
committed
update weakly supervised learning
add Mumford Shah loss
1 parent b67920f commit 02ed81f

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

pymic/loss/seg/mumford_shah.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss.

    Expects ``loss_input_dict`` with:
      - 'prediction': network output logits of shape (N, C, ...), or a
        list/tuple whose first element is used (deep supervision).
      - 'ground_truth': one-hot / soft label tensor of the same shape.

    ``params`` may contain 'loss_softmax' (default True) controlling whether
    softmax is applied to the prediction before computing Dice.
    """
    def __init__(self, params = None):
        super(DiceLoss, self).__init__()
        if(params is None):
            self.softmax = True
        else:
            self.softmax = params.get('loss_softmax', True)

    def forward(self, loss_input_dict):
        # BUG FIX: these helpers were referenced but never imported at module
        # level (the file only imports torch / torch.nn), which raised a
        # NameError at call time. Import them locally from the project's
        # utility module, matching the other loss files in this commit.
        from pymic.loss.seg.util import reshape_tensor_to_2D, get_classwise_dice
        predict = loss_input_dict['prediction']
        soft_y = loss_input_dict['ground_truth']

        # With deep supervision the prediction is a list; use the main output.
        if(isinstance(predict, (list, tuple))):
            predict = predict[0]
        if(self.softmax):
            predict = nn.Softmax(dim = 1)(predict)
        predict = reshape_tensor_to_2D(predict)
        soft_y = reshape_tensor_to_2D(soft_y)
        dice_score = get_classwise_dice(predict, soft_y)
        dice_loss = 1.0 - dice_score.mean()
        return dice_loss
28+
29+
class MumfordShahLoss(nn.Module):
    """
    Implementation of the Mumford-Shah loss described in:
    Boah Kim and Jong Chul Ye, "Mumford-Shah Loss Functional for Image
    Segmentation With Deep Learning", IEEE TIP, 2019.
    The original implementation is available at:
    https://github.com/jongcye/CNN_MumfordShah_Loss

    Currently only the 2D version is supported; 5D inputs are folded into a
    batch of 2D slices in ``forward``.

    Configuration keys (all optional):
      - 'loss_softmax': apply softmax to the prediction (default True).
      - 'MumfordShahLoss_penalty': "l1" or "l2" gradient penalty (default "l1").
      - 'MumfordShahLoss_lambda': weight of the gradient term (default 1.0).
    """
    def __init__(self, params = None):
        super(MumfordShahLoss, self).__init__()
        cfg = {} if params is None else params
        self.softmax = cfg.get('loss_softmax', True)
        self.penalty = cfg.get('MumfordShahLoss_penalty', "l1")
        self.grad_w = cfg.get('MumfordShahLoss_lambda', 1.0)

    def get_levelset_loss(self, output, target):
        """Piecewise-constant (level-set) fitting term.

        output: softmax output of a network, shape (N, C, H, W).
        target: the input image, shape (N, C_img, H, W).
        """
        num_cls = output.shape[1]
        loss = 0.0
        # Accumulate the fitting term independently over image channels.
        for ich in range(target.shape[1]):
            # (N, 1, H, W) slice broadcast against the C class maps.
            chan = target[:, ich:ich + 1].expand(-1, num_cls, -1, -1)
            # Per-class centroid: probability-weighted mean intensity.
            centroid = torch.sum(chan * output, (2, 3)) / torch.sum(output, (2, 3))
            centroid = centroid.view(centroid.shape[0], num_cls, 1, 1)
            # Squared deviation from the centroid, weighted by class probability.
            deviation = chan - centroid
            loss = loss + torch.sum(deviation * deviation * output)
        return loss

    def get_gradient_loss(self, pred, penalty = "l2"):
        """Total-variation style smoothness term on the prediction maps."""
        diff_h = (pred[:, :, 1:, :] - pred[:, :, :-1, :]).abs()
        diff_w = (pred[:, :, :, 1:] - pred[:, :, :, :-1]).abs()
        if penalty == "l2":
            diff_h = diff_h ** 2
            diff_w = diff_w ** 2
        return diff_h.sum() + diff_w.sum()

    def forward(self, loss_input_dict):
        predict = loss_input_dict['prediction']
        image = loss_input_dict['image']
        # With deep supervision the prediction is a list; use the main output.
        if isinstance(predict, (list, tuple)):
            predict = predict[0]
        if self.softmax:
            predict = nn.Softmax(dim = 1)(predict)

        if len(list(predict.shape)) == 5:
            # Fold depth into the batch axis: (N, C, D, H, W) -> (N*D, C, H, W),
            # so the 2D loss terms apply slice by slice.
            n, c, d, h, w = list(predict.shape)
            predict = torch.reshape(torch.transpose(predict, 1, 2), [n * d, c, h, w])
            n, c, d, h, w = list(image.shape)
            image = torch.reshape(torch.transpose(image, 1, 2), [n * d, c, h, w])

        fit_term = self.get_levelset_loss(predict, image)
        smooth_term = self.get_gradient_loss(predict, self.penalty)
        total = fit_term + self.grad_w * smooth_term
        # Normalize by the number of prediction elements.
        return total / torch.numel(predict)

pymic/net_run_wsl/wsl_main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
from pymic.util.parse_config import *
88
from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization
99
from pymic.net_run_wsl.wsl_gatedcrf import WSL_GatedCRF
10+
from pymic.net_run_wsl.wsl_mumford_shah import WSL_MumfordShah
1011
from pymic.net_run_wsl.wsl_tv import WSL_TotalVariation
1112
from pymic.net_run_wsl.wsl_ustm import WSL_USTM
1213
from pymic.net_run_wsl.wsl_dmpls import WSL_DMPLS
1314

1415
# Registry mapping the config-file method name to its weakly-supervised
# training agent class.
WSLMethodDict = {
    'EntropyMinimization': WSL_EntropyMinimization,
    'GatedCRF': WSL_GatedCRF,
    'MumfordShah': WSL_MumfordShah,
    'TotalVariation': WSL_TotalVariation,
    'USTM': WSL_USTM,
    'DMPLS': WSL_DMPLS}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
import logging
4+
import numpy as np
5+
import random
6+
import torch
7+
import torchvision.transforms as transforms
8+
from pymic.io.nifty_dataset import NiftyDataset
9+
from pymic.loss.seg.util import get_soft_label
10+
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
11+
from pymic.loss.seg.util import get_classwise_dice
12+
from pymic.loss.seg.mumford_shah import MumfordShahLoss
13+
from pymic.net_run.agent_seg import SegmentationAgent
14+
from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization
15+
from pymic.util.ramps import sigmoid_rampup
16+
17+
class WSL_MumfordShah(WSL_EntropyMinimization):
    """
    Training and testing agent for weakly supervised segmentation that
    regularizes the supervised loss with a Mumford-Shah term
    (``pymic.loss.seg.mumford_shah.MumfordShahLoss``).
    Note: the original docstring said "semi-supervised", but this agent reads
    the 'weakly_supervised_learning' config section.
    """
    def __init__(self, config, stage = 'train'):
        super(WSL_MumfordShah, self).__init__(config, stage)

    def training(self):
        """Run ``iter_valid`` training iterations and return averaged scalars.

        Returns a dict with keys 'loss', 'loss_sup', 'loss_reg', 'regular_w',
        'avg_dice' and 'class_dice'.
        """
        class_num  = self.config['network']['class_num']
        iter_valid = self.config['training']['iter_valid']
        wsl_cfg    = self.config['weakly_supervised_learning']
        # These config reads are loop-invariant; fetch them once instead of
        # once per iteration (the original re-read them inside the loop).
        iter_max       = self.config['training']['iter_max']
        iter_sup       = wsl_cfg.get('iter_sup', 0)
        ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max)
        train_loss     = 0
        train_loss_sup = 0
        train_loss_reg = 0
        train_dice_list = []

        reg_loss_calculator = MumfordShahLoss(wsl_cfg)
        self.net.train()
        for it in range(iter_valid):
            try:
                data = next(self.trainIter)
            except StopIteration:
                # Data loader exhausted: restart it for the next epoch.
                self.trainIter = iter(self.train_loader)
                data = next(self.trainIter)

            # get the inputs
            inputs = self.convert_tensor_type(data['image'])
            y      = self.convert_tensor_type(data['label_prob'])

            inputs, y = inputs.to(self.device), y.to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # forward + backward + optimize
            outputs  = self.net(inputs)
            loss_sup = self.get_loss_value(data, outputs, y)
            loss_dict = {"prediction": outputs, 'image': inputs}
            loss_reg = reg_loss_calculator(loss_dict)

            # Regularization weight: 0 during the supervised-only warm-up,
            # then ramped up to 'regularize_w' with a sigmoid schedule.
            # NOTE(review): ramp-up nesting mirrors the sibling WSL agents —
            # confirm against wsl_em.py.
            regular_w = 0.0
            if(self.glob_it > iter_sup):
                regular_w = wsl_cfg.get('regularize_w', 0.1)
                if(ramp_up_length is not None and self.glob_it < ramp_up_length):
                    regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
            loss = loss_sup + regular_w * loss_reg
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            train_loss     = train_loss + loss.item()
            train_loss_sup = train_loss_sup + loss_sup.item()
            train_loss_reg = train_loss_reg + loss_reg.item()
            # get dice evaluation for each class in annotated images
            if(isinstance(outputs, (tuple, list))):
                outputs = outputs[0]
            p_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
            p_soft   = get_soft_label(p_argmax, class_num, self.tensor_type)
            p_soft, y = reshape_prediction_and_ground_truth(p_soft, y)
            dice_list = get_classwise_dice(p_soft, y)
            train_dice_list.append(dice_list.cpu().numpy())
        train_avg_loss     = train_loss / iter_valid
        train_avg_loss_sup = train_loss_sup / iter_valid
        train_avg_loss_reg = train_loss_reg / iter_valid
        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
        train_avg_dice = train_cls_dice.mean()

        train_scalers = {'loss': train_avg_loss, 'loss_sup': train_avg_loss_sup,
            'loss_reg': train_avg_loss_reg, 'regular_w': regular_w,
            'avg_dice': train_avg_dice, 'class_dice': train_cls_dice}
        return train_scalers
91+

pymic/util/evaluation_seg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def evaluation(config_file):
222222
for i in range(item_num):
223223
gt_name = image_items.iloc[i, 0]
224224
seg_name = image_items.iloc[i, 1]
225+
# seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz")
225226
gt_full_name = gt_root + '/' + gt_name
226227
seg_full_name = seg_root_n + '/' + seg_name
227228

0 commit comments

Comments
 (0)