Skip to content

Commit 87f2935

Browse files
committed
update net run agent and mse loss
1 parent 0d03c25 commit 87f2935

File tree

6 files changed

+191
-21
lines changed

6 files changed

+191
-21
lines changed

pymic/io/nifty_dataset.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,46 @@ class NiftyDataset(Dataset):
1616
with dimention order [C, H, W] for 2D images"""
1717

1818
def __init__(self, root_dir, csv_file, modal_num = 1,
19-
with_label = False, transform=None):
19+
with_label = False, with_weight = None, transform=None):
2020
"""
2121
Args:
2222
root_dir (string): Directory with all the images.
2323
csv_file (string): Path to the csv file with image names.
2424
modal_num (int): Number of modalities.
2525
with_label (bool): Load the data with segmentation ground truth.
26+
with_weight(bool): Load pixel-wise weight map.
2627
transform (callable, optional): Optional transform to be applied
2728
on a sample.
2829
"""
2930
self.root_dir = root_dir
3031
self.csv_items = pd.read_csv(csv_file)
3132
self.modal_num = modal_num
3233
self.with_label = with_label
34+
self.with_weight= with_weight
3335
self.transform = transform
3436

37+
if(self.with_label):
38+
self.label_idx = list(self.csv_items.keys()).index('label')
39+
if(self.with_weight):
40+
self.weight_idx = list(self.csv_items.keys()).index('weight')
41+
3542
def __len__(self):
3643
return len(self.csv_items)
3744

45+
def __getlabel__(self, idx):
46+
label_name = "{0:}/{1:}".format(self.root_dir,
47+
self.csv_items.iloc[idx, self.label_idx])
48+
label = load_image_as_nd_array(label_name)['data_array']
49+
label = np.asarray(label, np.int32)
50+
return label
51+
52+
def __getweight__(self, idx):
53+
weight_name = "{0:}/{1:}".format(self.root_dir,
54+
self.csv_items.iloc[idx, self.weight_idx])
55+
weight = load_image_as_nd_array(weight_name)['data_array']
56+
weight = np.asarray(weight, np.float32)
57+
return weight
58+
3859
def __getitem__(self, idx):
3960
names_list, image_list = [], []
4061
for i in range (self.modal_num):
@@ -46,17 +67,32 @@ def __getitem__(self, idx):
4667
image_list.append(image_data)
4768
image = np.concatenate(image_list, axis = 0)
4869
image = np.asarray(image, np.float32)
49-
5070
sample = {'image': image, 'names' : names_list[0],
5171
'origin':image_dict['origin'],
5272
'spacing': image_dict['spacing'],
5373
'direction':image_dict['direction']}
54-
if (self.with_label):
55-
label_name = "{0:}/{1:}".format(self.root_dir, self.csv_items.iloc[idx, -1])
56-
label = load_image_as_nd_array(label_name)['data_array']
57-
label = np.asarray(label, np.int32)
58-
sample['label'] = label
74+
if (self.with_label):
75+
sample['label'] = self.__getlabel__(idx)
76+
assert(image.shape[1:] == sample['label'].shape[1:])
77+
if (self.with_weight):
78+
sample['weight'] = self.__getweight__(idx)
79+
assert(image.shape[1:] == sample['weight'].shape[1:])
5980
if self.transform:
6081
sample = self.transform(sample)
6182

6283
return sample
84+
85+
86+
class ClassificationDataset(NiftyDataset):
87+
def __init__(self, root_dir, csv_file, modal_num = 1, class_num = 2,
88+
with_label = False, transform=None):
89+
super(ClassificationDataset, self).__init__(root_dir,
90+
csv_file, modal_num, with_label, transform)
91+
self.class_num = class_num
92+
print("class number for ClassificationDataset", self.class_num)
93+
94+
def __getlabel__(self, idx):
95+
label_idx = self.csv_items.iloc[idx, -1]
96+
label = np.zeros((self.class_num, ))
97+
label[label_idx] = 1
98+
return label

pymic/loss/loss_dict.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
from pymic.loss.dice import DiceLoss, MultiScaleDiceLoss
55
from pymic.loss.dice import DiceWithCrossEntropyLoss, NoiseRobustDiceLoss
66
from pymic.loss.exp_log import ExpLogLoss
7+
from pymic.loss.mse import MSELoss, MAELoss
78

89
LossDict = {'CrossEntropyLoss': CrossEntropyLoss,
910
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
1011
'DiceLoss': DiceLoss,
1112
'MultiScaleDiceLoss': MultiScaleDiceLoss,
1213
'DiceWithCrossEntropyLoss': DiceWithCrossEntropyLoss,
1314
'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
14-
'ExpLogLoss': ExpLogLoss}
15+
'ExpLogLoss': ExpLogLoss,
16+
'MSELoss': MSELoss,
17+
'MAELoss': MAELoss}

pymic/loss/mse.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import torch
2+
import torch.nn as nn
3+
from pymic.loss.util import reshape_tensor_to_2D
4+
5+
class MSELoss(nn.Module):
6+
def __init__(self, params):
7+
super(MSELoss, self).__init__()
8+
self.enable_pix_weight = params['MSELoss_Enable_Pixel_Weight'.lower()]
9+
self.enable_cls_weight = params['MSELoss_Enable_Class_Weight'.lower()]
10+
11+
def forward(self, loss_input_dict):
12+
predict = loss_input_dict['prediction']
13+
soft_y = loss_input_dict['ground_truth']
14+
pix_w = loss_input_dict['pixel_weight']
15+
cls_w = loss_input_dict['class_weight']
16+
softmax = loss_input_dict['softmax']
17+
18+
if(softmax):
19+
predict = nn.Softmax(dim = 1)(predict)
20+
predict = reshape_tensor_to_2D(predict)
21+
soft_y = reshape_tensor_to_2D(soft_y)
22+
se = self.get_prediction_error(predict, soft_y)
23+
if(self.enable_cls_weight):
24+
if(cls_w is None):
25+
raise ValueError("Class weight is enabled but not defined")
26+
mse = torch.sum(se * cls_w, dim = 1) / torch.sum(cls_w)
27+
else:
28+
mse = torch.mean(se, dim = 1)
29+
if(self.enable_pix_weight):
30+
if(pix_w is None):
31+
raise ValueError("Pixel weight is enabled but not defined")
32+
pix_w = reshape_tensor_to_2D(pix_w)
33+
mse = torch.sum(mse * pix_w) / torch.sum(pix_w)
34+
else:
35+
mse = torch.mean(mse)
36+
return mse
37+
38+
def get_prediction_error(self, predict, soft_y):
39+
diff = predict - soft_y
40+
error = diff * diff
41+
return error
42+
43+
class MAELoss(nn.MSELoss):
44+
def __init__(self, params):
45+
super(MAELoss, self).__init__(params)
46+
47+
def get_prediction_error(self, predict, soft_y):
48+
diff = predict - soft_y
49+
error = torch.abs(diff)
50+
return error

pymic/net_run/net_run_agent.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,30 @@ def get_stage_dataset_from_config(self, stage):
6666

6767
if(stage == "train" or stage == "valid"):
6868
transform_names = self.config['dataset']['train_transform']
69-
with_weight = self.config['dataset']['load_pixelwise_weight']
69+
with_weight = self.config['dataset'].get('load_pixelwise_weight', False)
7070
elif(stage == "test"):
7171
transform_names = self.config['dataset']['test_transform']
7272
with_weight = False
7373
else:
7474
raise ValueError("Incorrect value for stage: {0:}".format(stage))
7575
self.transform_list = []
76-
for name in transform_names:
77-
if(name not in self.transform_dict):
78-
raise(ValueError("Undefined transform {0:}".format(name)))
79-
one_transform = self.transform_dict[name](self.config['dataset'])
80-
self.transform_list.append(one_transform)
76+
if(transform_names is None or len(transform_names) == 0):
77+
data_transform = None
78+
else:
79+
for name in transform_names:
80+
if(name not in self.transform_dict):
81+
raise(ValueError("Undefined transform {0:}".format(name)))
82+
one_transform = self.transform_dict[name](self.config['dataset'])
83+
self.transform_list.append(one_transform)
84+
data_transform = transforms.Compose(self.transform_list)
8185

8286
csv_file = self.config['dataset'].get(stage + '_csv', None)
8387
dataset = NiftyDataset(root_dir=root_dir,
8488
csv_file = csv_file,
8589
modal_num = modal_num,
8690
with_label= not (stage == 'test'),
8791
with_weight = with_weight,
88-
transform = transforms.Compose(self.transform_list))
92+
transform = data_transform )
8993
return dataset
9094

9195
def create_dataset(self):
@@ -141,7 +145,6 @@ def convert_tensor_type(self, input_tensor):
141145
def train(self):
142146
device = torch.device(self.config['training']['device_name'])
143147
self.net.to(device)
144-
145148
class_num = self.config['network']['class_num']
146149
summ_writer = SummaryWriter(self.config['training']['summary_dir'])
147150
chpt_prefx = self.config['training']['checkpoint_prefix']
@@ -171,7 +174,7 @@ def train(self):
171174
if(loss_name not in self.loss_dict):
172175
raise ValueError("Undefined loss function {0:}".format(loss_name))
173176
self.loss_calculater = self.loss_dict[loss_name](self.config['training'])
174-
177+
175178
trainIter = iter(self.train_loader)
176179
train_loss = 0
177180
train_dice_list = []

pymic/util/evaluation.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import math
7+
import pandas as pd
78
import random
89
import GeodisTK
910
import configparser
@@ -180,7 +181,7 @@ def get_evaluation_score(s_volume, g_volume, spacing, metric):
180181

181182
return score
182183

183-
def evaluation(config_file):
184+
def evaluation_backup(config_file):
184185
config = parse_config(config_file)['evaluation']
185186
metric = config['metric']
186187
labels = config['label_list']
@@ -218,7 +219,9 @@ def evaluation(config_file):
218219

219220
for g_folder in g_folder_list:
220221
g_name = os.path.join(g_folder, patient_names[i] + g_postfix_long)
221-
if(os.path.isfile(g_name)):
222+
if(not os.path.isfile(g_name)):
223+
g_name = g_name.replace(patient_names[i], patient_names[i] + '/' + patient_names[i])
224+
if(not os.path.isfile(g_name)):
222225
break
223226
s_dict = load_image_as_nd_array(s_name)
224227
g_dict = load_image_as_nd_array(g_name)
@@ -257,6 +260,64 @@ def evaluation(config_file):
257260
print("{0:} mean ".format(metric), score_mean)
258261
print("{0:} std ".format(metric), score_std)
259262

263+
def evaluation(config_file):
264+
config = parse_config(config_file)['evaluation']
265+
metric = config['metric']
266+
labels = config['label_list']
267+
organ_name = config['organ_name']
268+
gt_root = config['ground_truth_folder_root']
269+
seg_root = config['segmentation_folder_root']
270+
image_pair_csv = config['evaluation_image_pair']
271+
ground_truth_label_convert_source = config.get('ground_truth_label_convert_source', None)
272+
ground_truth_label_convert_target = config.get('ground_truth_label_convert_target', None)
273+
segmentation_label_convert_source = config.get('segmentation_label_convert_source', None)
274+
segmentation_label_convert_target = config.get('segmentation_label_convert_target', None)
275+
276+
image_items = pd.read_csv(image_pair_csv)
277+
item_num = len(image_items)
278+
score_all_data = []
279+
for i in range(item_num):
280+
gt_name = image_items.iloc[i, 0]
281+
seg_name = image_items.iloc[i, 1]
282+
gt_full_name = gt_root + '/' + gt_name
283+
seg_full_name = seg_root + '/' + seg_name
284+
285+
s_dict = load_image_as_nd_array(seg_full_name)
286+
g_dict = load_image_as_nd_array(gt_full_name)
287+
s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"]
288+
g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"]
289+
# for dim in range(len(s_spacing)):
290+
# assert(s_spacing[dim] == g_spacing[dim])
291+
if((ground_truth_label_convert_source is not None) and \
292+
ground_truth_label_convert_target is not None):
293+
g_volume = convert_label(g_volume, ground_truth_label_convert_source, \
294+
ground_truth_label_convert_target)
295+
296+
if((segmentation_label_convert_source is not None) and \
297+
segmentation_label_convert_target is not None):
298+
s_volume = convert_label(s_volume, segmentation_label_convert_source, \
299+
segmentation_label_convert_target)
300+
301+
# fuse multiple labels
302+
s_volume_sub = np.zeros_like(s_volume)
303+
g_volume_sub = np.zeros_like(g_volume)
304+
for lab in labels:
305+
s_volume_sub = s_volume_sub + np.asarray(s_volume == lab, np.uint8)
306+
g_volume_sub = g_volume_sub + np.asarray(g_volume == lab, np.uint8)
307+
308+
# get evaluation score
309+
temp_score = get_evaluation_score(s_volume_sub > 0, g_volume_sub > 0,
310+
s_spacing, metric)
311+
score_all_data.append(temp_score)
312+
print(seg_name, temp_score)
313+
score_all_data = np.asarray(score_all_data)
314+
score_mean = [score_all_data.mean(axis = 0)]
315+
score_std = [score_all_data.std(axis = 0)]
316+
np.savetxt("{0:}/{1:}_{2:}_all.txt".format(seg_root, organ_name, metric), score_all_data)
317+
np.savetxt("{0:}/{1:}_{2:}_mean.txt".format(seg_root, organ_name, metric), score_mean)
318+
np.savetxt("{0:}/{1:}_{2:}_std.txt".format(seg_root, organ_name, metric), score_std)
319+
print("{0:} mean ".format(metric), score_mean)
320+
print("{0:} std ".format(metric), score_std)
260321

261322
def main():
262323
if(len(sys.argv) < 2):

pymic/util/image_process.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import absolute_import, print_function
33

44
import numpy as np
5+
import SimpleITK as sitk
56
from scipy import ndimage
67

78
def get_ND_bounding_box(volume, margin = None):
@@ -16,8 +17,8 @@ def get_ND_bounding_box(volume, margin = None):
1617
idx_min = []
1718
idx_max = []
1819
for i in range(len(input_shape)):
19-
idx_min.append(indxes[i].min())
20-
idx_max.append(indxes[i].max() + 1)
20+
idx_min.append(int(indxes[i].min()))
21+
idx_max.append(int(indxes[i].max()) + 1)
2122

2223
for i in range(len(input_shape)):
2324
idx_min[i] = max(idx_min[i] - margin[i], 0)
@@ -162,3 +163,19 @@ def convert_label(label, source_list, target_list):
162163
label_temp = label_temp * target_list[i]
163164
label_converted = label_converted + label_temp
164165
return label_converted
166+
167+
def resample_sitk_image_to_given_spacing(image, spacing, order):
168+
"""
169+
image: an sitk image object
170+
spacing: 3D tuple / list for spacing along x, y, z direction
171+
order: order for interpolation
172+
"""
173+
spacing0 = image.GetSpacing()
174+
data = sitk.GetArrayFromImage(image)
175+
zoom = [spacing0[i] / spacing[i] for i in range(3)]
176+
zoom = [zoom[2], zoom[0], zoom[1]]
177+
data = ndimage.interpolation.zoom(data, zoom, order = order)
178+
out_img = sitk.GetImageFromArray(data)
179+
out_img.SetSpacing(spacing)
180+
out_img.SetDirection(image.GetDirection())
181+
return out_img

0 commit comments

Comments
 (0)