Skip to content

Commit e594ef8

Browse files
committed
Merge branch 'dev'
2 parents 34865c5 + 991995b commit e594ef8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+3086
-345
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
11
*.pyc
2+
build/*
3+
dist/*
4+
*egg*/*
5+
*stop*
6+
files.txt

pymic/io/h5_dataset.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from logging import root
4+
import os
5+
from re import S
6+
import torch
7+
import random
8+
import h5py
9+
import pandas as pd
10+
from scipy import ndimage
11+
from torch.utils.data import Dataset
12+
from torch.utils.data.sampler import Sampler
13+
14+
class H5DataSets(Dataset):
15+
"""
16+
Dataset for loading images stored in h5 format. It generates
17+
4D tensors with dimention order [C, D, H, W] for 3D images, and
18+
3D tensors with dimention order [C, H, W] for 2D images
19+
"""
20+
def __init__(self, root_dir, sample_list_name, transform = None):
21+
self.root_dir = root_dir
22+
self.transform = transform
23+
with open(sample_list_name, 'r') as f:
24+
lines = f.readlines()
25+
self.sample_list = [item.replace('\n', '') for item in lines]
26+
27+
def __len__(self):
28+
return len(self.sample_list)
29+
30+
def __getitem__(self, idx):
31+
sample_name = self.sample_list[idx]
32+
h5f = h5py.File(self.root_dir + '/' + sample_name, 'r')
33+
image = h5f['image'][:]
34+
label = h5f['label'][:]
35+
sample = {'image': image, 'label': label}
36+
if self.transform:
37+
sample = self.transform(sample)
38+
# sample["idx"] = idx
39+
return sample
40+
41+
class TwoStreamBatchSampler(Sampler):
42+
"""Iterate two sets of indices
43+
44+
An 'epoch' is one iteration through the primary indices.
45+
During the epoch, the secondary indices are iterated through
46+
as many times as needed.
47+
"""
48+
49+
def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
50+
self.primary_indices = primary_indices
51+
self.secondary_indices = secondary_indices
52+
self.secondary_batch_size = secondary_batch_size
53+
self.primary_batch_size = batch_size - secondary_batch_size
54+
55+
assert len(self.primary_indices) >= self.primary_batch_size > 0
56+
assert len(self.secondary_indices) >= self.secondary_batch_size > 0
57+
58+
def __iter__(self):
59+
primary_iter = iterate_once(self.primary_indices)
60+
secondary_iter = iterate_eternally(self.secondary_indices)
61+
return (
62+
primary_batch + secondary_batch
63+
for (primary_batch, secondary_batch)
64+
in zip(grouper(primary_iter, self.primary_batch_size),
65+
grouper(secondary_iter, self.secondary_batch_size))
66+
)
67+
68+
def __len__(self):
69+
return len(self.primary_indices) // self.primary_batch_size
70+
71+
72+
def iterate_once(iterable):
73+
return np.random.permutation(iterable)
74+
75+
76+
def iterate_eternally(indices):
77+
def infinite_shuffles():
78+
while True:
79+
yield np.random.permutation(indices)
80+
return itertools.chain.from_iterable(infinite_shuffles())
81+
82+
83+
def grouper(iterable, n):
84+
"Collect data into fixed-length chunks or blocks"
85+
# grouper('ABCDEFG', 3) --> ABC DEF"
86+
args = [iter(iterable)] * n
87+
return zip(*args)
88+
89+
90+
if __name__ == "__main__":
91+
root_dir = "/home/guotai/disk2t/projects/semi_supervise/SSL4MIS/data/ACDC/data/slices"
92+
file_name = "/home/guotai/disk2t/projects/semi_supervise/slices.txt"
93+
dataset = H5DataSets(root_dir, file_name)
94+
train_loader = torch.utils.data.DataLoader(dataset,
95+
batch_size = 4, shuffle=True, num_workers= 1)
96+
for sample in train_loader:
97+
image = sample['image']
98+
label = sample['label']
99+
print(image.shape, label.shape)
100+
print(image.min(), image.max(), label.max())

pymic/io/image_read_write.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
import os
55
import numpy as np
66
import SimpleITK as sitk
7-
87
from PIL import Image
98

10-
119
def load_nifty_volume_as_4d_array(filename):
1210
"""Read a nifty image and return a dictionay storing data array, spacing and direction
1311
output['data_array'] 4d array with shape [C, D, H, W]

pymic/io/nifty_dataset.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
import pandas as pd
77
import numpy as np
8-
98
from torch.utils.data import Dataset, DataLoader
109
from torchvision import transforms, utils
1110
from pymic.io.image_read_write import load_image_as_nd_array
@@ -95,7 +94,6 @@ def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2,
9594
super(ClassificationDataset, self).__init__(root_dir,
9695
csv_file, modal_num, with_label, transform)
9796
self.class_num = class_num
98-
print("class number for ClassificationDataset", self.class_num)
9997

10098
def __getlabel__(self, idx):
10199
csv_keys = list(self.csv_items.keys())
@@ -129,6 +127,4 @@ def __getitem__(self, idx):
129127
sample['image_weight'] = self.__getweight__(idx)
130128
if self.transform:
131129
sample = self.transform(sample)
132-
133130
return sample
134-

pymic/loss/cls/ce.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,13 @@ class SigmoidCELoss(nn.Module):
2929
Args:
3030
predict has a shape of [N, C] where C is the class number
3131
labels has a shape of [N, C] with binary values
32-
3332
note that predict is the digit output of a network, before using sigmoid."""
3433
def __init__(self, params):
3534
super(SigmoidCELoss, self).__init__()
3635

3736
def forward(self, loss_input_dict):
3837
predict = loss_input_dict['prediction']
3938
labels = loss_input_dict['ground_truth']
40-
# for numeric stability
4139
predict = nn.Sigmoid()(predict) * 0.999 + 5e-4
4240
loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict)
4341
loss = loss.mean()

pymic/loss/cls/l1.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from pymic.loss.cls.util import get_soft_label
77

88
class L1Loss(nn.Module):
9+
"""
10+
L1 (MAE) loss for classification
11+
"""
912
def __init__(self, params):
1013
super(L1Loss, self).__init__()
1114
self.l1_loss = nn.L1Loss()
@@ -20,3 +23,30 @@ def forward(self, loss_input_dict):
2023
soft_y = get_soft_label(labels, num_class, data_type)
2124
loss = self.l1_loss(predict, soft_y)
2225
return loss
26+
27+
class RectifiedLoss(nn.Module):
28+
def __init__(self, params):
29+
super(RectifiedLoss, self).__init__()
30+
# self.l1_loss = nn.L1Loss()
31+
32+
def forward(self, loss_input_dict):
33+
predict = loss_input_dict['prediction']
34+
labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
35+
36+
# softmax = nn.Softmax(dim = 1)
37+
# predict = softmax(predict)
38+
num_class = list(predict.size())[1]
39+
data_type = 'float' if(predict.dtype is torch.float32) else 'double'
40+
soft_y = get_soft_label(labels, num_class, data_type)
41+
g = 2* soft_y - 1
42+
loss = torch.exp((g*1.5- predict) * g)
43+
mask = predict < g
44+
if (data_type == 'float'):
45+
mask = mask.float()
46+
else:
47+
mask = mask.double()
48+
w = (mask - 0.5) * g + 0.5
49+
loss = w * loss + 0.1*(g - predict) * (g - predict)
50+
loss = loss.mean()
51+
# loss = self.l1_loss(predict, soft_y)
52+
return loss

pymic/loss/loss_dict_seg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from __future__ import print_function, division
33
import torch.nn as nn
44
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
5-
from pymic.loss.seg.dice import DiceLoss, MultiScaleDiceLoss
6-
from pymic.loss.seg.dice import FocalDiceLoss, NoiseRobustDiceLoss
5+
from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss
6+
from pymic.loss.seg.slsr import SLSRLoss
77
from pymic.loss.seg.exp_log import ExpLogLoss
88
from pymic.loss.seg.mse import MSELoss, MAELoss
99

1010
SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss,
1111
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
12+
'SLSRLoss': SLSRLoss,
1213
'DiceLoss': DiceLoss,
13-
'MultiScaleDiceLoss': MultiScaleDiceLoss,
1414
'FocalDiceLoss': FocalDiceLoss,
1515
'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
1616
'ExpLogLoss': ExpLogLoss,

pymic/loss/seg/ce.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,55 @@
88
class CrossEntropyLoss(nn.Module):
99
def __init__(self, params):
1010
super(CrossEntropyLoss, self).__init__()
11-
self.enable_pix_weight = params.get('CrossEntropyLoss_Enable_Pixel_Weight'.lower(), False)
12-
self.enable_cls_weight = params.get('CrossEntropyLoss_Enable_Class_Weight'.lower(), False)
11+
if(params is None):
12+
self.softmax = True
13+
else:
14+
self.softmax = params.get('loss_softmax', True)
1315

1416
def forward(self, loss_input_dict):
1517
predict = loss_input_dict['prediction']
1618
soft_y = loss_input_dict['ground_truth']
17-
pix_w = loss_input_dict['pixel_weight']
18-
cls_w = loss_input_dict['class_weight']
19-
softmax = loss_input_dict['softmax']
19+
pix_w = loss_input_dict.get('pixel_weight', None)
2020

2121
if(isinstance(predict, (list, tuple))):
2222
predict = predict[0]
23-
if(softmax):
23+
if(self.softmax):
2424
predict = nn.Softmax(dim = 1)(predict)
2525
predict = reshape_tensor_to_2D(predict)
2626
soft_y = reshape_tensor_to_2D(soft_y)
2727

2828
# for numeric stability
2929
predict = predict * 0.999 + 5e-4
3030
ce = - soft_y* torch.log(predict)
31-
if(self.enable_cls_weight):
32-
if(cls_w is None):
33-
raise ValueError("Class weight is enabled but not defined")
34-
ce = torch.sum(ce * cls_w, dim = 1)
35-
else:
36-
ce = torch.sum(ce, dim = 1) # shape is [N]
37-
if(self.enable_pix_weight):
38-
if(pix_w is None):
39-
raise ValueError("Pixel weight is enabled but not defined")
40-
pix_w = reshape_tensor_to_2D(pix_w) # shape is [N, 1]
41-
pix_w = torch.squeeze(pix_w) # squeeze to [N]
42-
ce = torch.sum(ce * pix_w) / torch.sum(pix_w)
43-
else:
31+
ce = torch.sum(ce, dim = 1) # shape is [N]
32+
if(pix_w is None):
4433
ce = torch.mean(ce)
34+
else:
35+
pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w))
36+
ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5)
37+
return ce
38+
39+
class PartialCrossEntropyLoss(nn.Module):
40+
def __init__(self, params):
41+
super(CrossEntropyLoss, self).__init__()
42+
self.softmax = params.get('loss_softmax', True)
43+
44+
def forward(self, loss_input_dict):
45+
predict = loss_input_dict['prediction']
46+
soft_y = loss_input_dict['ground_truth']
47+
48+
if(isinstance(predict, (list, tuple))):
49+
predict = predict[0]
50+
if(self.softmax):
51+
predict = nn.Softmax(dim = 1)(predict)
52+
predict = reshape_tensor_to_2D(predict)
53+
soft_y = reshape_tensor_to_2D(soft_y)
54+
55+
# for numeric stability
56+
predict = predict * 0.999 + 5e-4
57+
ce = - soft_y* torch.log(predict)
58+
ce = torch.sum(ce, dim = 1) # shape is [N]
59+
ce = torch.mean(ce)
4560
return ce
4661

4762
class GeneralizedCrossEntropyLoss(nn.Module):
@@ -85,4 +100,4 @@ def forward(self, loss_input_dict):
85100
gce = torch.sum(gce * pix_w) / torch.sum(pix_w)
86101
else:
87102
gce = torch.mean(gce)
88-
return gce
103+
return gce

pymic/loss/seg/deep_sup.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch.nn as nn
5+
6+
class DeepSuperviseLoss(nn.Module):
7+
def __init__(self, params):
8+
super(DeepSuperviseLoss, self).__init__()
9+
self.deep_sup_weight = params.get('deep_suervise_weight', None)
10+
self.base_loss = params['base_loss']
11+
12+
def forward(self, loss_input_dict):
13+
predict = loss_input_dict['prediction']
14+
if(not isinstance(predict, (list,tuple))):
15+
raise ValueError("""For deep supervision, the prediction should
16+
be a list or a tuple""")
17+
predict_num = len(predict)
18+
if(self.deep_sup_weight is None):
19+
self.deep_sup_weight = [1.0] * predict_num
20+
else:
21+
assert(predict_num == len(self.deep_sup_weight))
22+
loss_sum, weight_sum = 0.0, 0.0
23+
for i in range(predict_num):
24+
loss_input_dict['prediction'] = predict[i]
25+
temp_loss = self.base_loss(loss_input_dict)
26+
loss_sum += temp_loss * self.deep_sup_weight[i]
27+
weight_sum += self.deep_sup_weight[i]
28+
loss = loss_sum/weight_sum
29+
return loss

0 commit comments

Comments
 (0)