Commit ea24a5d

update transform
1 parent 1f2d34d commit ea24a5d

File tree

12 files changed, +807 -731 lines

pymic/transform/abstract_transform.py

Lines changed: 12 additions & 0 deletions

# -*- coding: utf-8 -*-
from __future__ import print_function, division


class AbstractTransform(object):
    def __init__(self, params):
        pass

    def __call__(self, sample):
        return sample

    def inverse_transform_for_prediction(self, sample):
        raise NotImplementedError("inverse_transform_for_prediction is not implemented")
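
For context, a minimal sketch of how a concrete transform plugs into this interface (the AddOffset class and its parameter keys are hypothetical, not part of this commit): __call__ edits the sample dict and records any parameters needed to undo the operation, and inverse_transform_for_prediction replays them on the prediction.

# hypothetical illustration, not part of this commit
class AddOffset(AbstractTransform):
    def __init__(self, params):
        # parameter keys follow the lowercased "ClassName_field" convention
        self.offset = params['AddOffset_offset'.lower()]
        self.inverse = params['AddOffset_inverse'.lower()]

    def __call__(self, sample):
        sample['image'] = sample['image'] + self.offset
        return sample

    def inverse_transform_for_prediction(self, sample):
        # an intensity offset does not change the spatial grid,
        # so the prediction needs no adjustment
        return sample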

pymic/transform/crop.py

Lines changed: 200 additions & 0 deletions

# -*- coding: utf-8 -*-
from __future__ import print_function, division

import torch
import json
import math
import random
import numpy as np
from scipy import ndimage
from pymic.transform.abstract_transform import AbstractTransform
from pymic.util.image_process import *


class CropWithBoundingBox(AbstractTransform):
    """Crop the image (shape [C, D, H, W] or [C, H, W]) based on the
    bounding box of the non-zero region.
    """
    def __init__(self, params):
        """
        start (None or tuple/list): The start index along each spatial axis.
            If None, calculate the start index automatically so that
            the cropped region is centered at the non-zero region.
        output_size (None or tuple/list): Desired spatial output size.
            If None, set it as the size of the bounding box of the non-zero region.
        """
        self.start = params['CropWithBoundingBox_start'.lower()]
        self.output_size = params['CropWithBoundingBox_output_size'.lower()]
        self.inverse = params['CropWithBoundingBox_inverse'.lower()]

    def __call__(self, sample):
        image = sample['image']
        input_shape = image.shape
        input_dim = len(input_shape) - 1
        bb_min, bb_max = get_ND_bounding_box(image)
        bb_min, bb_max = bb_min[1:], bb_max[1:]
        if(self.start is None):
            if(self.output_size is None):
                crop_min, crop_max = bb_min, bb_max
            else:
                assert(len(self.output_size) == input_dim)
                # center the crop at the non-zero region
                crop_min = [int((bb_min[i] + bb_max[i] + 1)/2) - int(self.output_size[i]/2) \
                    for i in range(input_dim)]
                crop_min = [max(0, crop_min[i]) for i in range(input_dim)]
                crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)]
        else:
            assert(len(self.start) == input_dim)
            crop_min = list(self.start)
            if(self.output_size is None):
                crop_max = [crop_min[i] + bb_max[i] - bb_min[i] \
                    for i in range(input_dim)]
            else:
                assert(len(self.output_size) == input_dim)
                crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)]
        crop_min = [0] + crop_min
        crop_max = list(input_shape[0:1]) + crop_max
        sample['CropWithBoundingBox_Param'] = json.dumps((input_shape, crop_min, crop_max))

        image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max)
        sample['image'] = image_t

        if('label' in sample and sample['label'].shape[1:] == image.shape[1:]):
            label = sample['label']
            crop_max[0] = label.shape[0]
            label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max)
            sample['label'] = label
        if('weight' in sample and sample['weight'].shape[1:] == image.shape[1:]):
            weight = sample['weight']
            crop_max[0] = weight.shape[0]
            weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max)
            sample['weight'] = weight
        return sample

    def inverse_transform_for_prediction(self, sample):
        ''' Rescale sample['predict'] (5D or 4D) to the original spatial shape.
        Assume the batch size is 1, otherwise the crop parameters may be different
        for different elements in the batch.

        origin_shape is a 4D or 3D vector as saved in __call__().'''
        if(isinstance(sample['CropWithBoundingBox_Param'], (list, tuple))):
            params = json.loads(sample['CropWithBoundingBox_Param'][0])
        else:
            params = json.loads(sample['CropWithBoundingBox_Param'])
        origin_shape = params[0]
        crop_min = params[1]
        crop_max = params[2]
        predict = sample['predict']
        if(isinstance(predict, (tuple, list))):
            output_predict = []
            for predict_i in predict:
                # use loop-local copies so that every element of the list
                # is restored from the original crop parameters
                origin_shape_i = list(predict_i.shape[:2]) + origin_shape[1:]
                output_predict_i = np.zeros(origin_shape_i, predict_i.dtype)
                crop_min_i = [0, 0] + crop_min[1:]
                crop_max_i = list(predict_i.shape[:2]) + crop_max[1:]
                output_predict_i = set_ND_volume_roi_with_bounding_box_range(output_predict_i,
                    crop_min_i, crop_max_i, predict_i)
                output_predict.append(output_predict_i)
        else:
            origin_shape = list(predict.shape[:2]) + origin_shape[1:]
            output_predict = np.zeros(origin_shape, predict.dtype)
            crop_min = [0, 0] + crop_min[1:]
            crop_max = list(predict.shape[:2]) + crop_max[1:]
            output_predict = set_ND_volume_roi_with_bounding_box_range(output_predict,
                crop_min, crop_max, predict)

        sample['predict'] = output_predict
        return sample


class RandomCrop(AbstractTransform):
    """Randomly crop the input image (shape [C, D, H, W] or [C, H, W]).
    """
    def __init__(self, params):
        """
        output_size (tuple or list): Desired output size [D, H, W] or [H, W].
            The output channel number is the same as that of the input.
        foreground_focus (bool): If True, allow cropping around the foreground.
        foreground_ratio (float): The probability of foreground-focused
            cropping when foreground_focus is True.
        mask_label (None, or tuple/list): The foreground labels for
            foreground-focused cropping.
        """
        self.output_size = params['RandomCrop_output_size'.lower()]
        self.fg_focus = params['RandomCrop_foreground_focus'.lower()]
        self.fg_ratio = params['RandomCrop_foreground_ratio'.lower()]
        self.mask_label = params['RandomCrop_mask_label'.lower()]
        self.inverse = params['RandomCrop_inverse'.lower()]
        assert isinstance(self.output_size, (list, tuple))
        if(self.mask_label is not None):
            assert isinstance(self.mask_label, (list, tuple))

    def __call__(self, sample):
        image = sample['image']
        input_shape = image.shape
        input_dim = len(input_shape) - 1

        assert(input_dim == len(self.output_size))
        crop_margin = [input_shape[i + 1] - self.output_size[i] \
            for i in range(input_dim)]
        crop_min = [random.randint(0, item) for item in crop_margin]
        if(self.fg_focus and random.random() < self.fg_ratio):
            # center the crop near a randomly chosen foreground position
            label = sample['label']
            mask = np.zeros_like(label)
            for temp_lab in self.mask_label:
                mask = np.maximum(mask, label == temp_lab)
            if(mask.sum() == 0):
                bb_min = [0] * (input_dim + 1)
                bb_max = mask.shape
            else:
                bb_min, bb_max = get_ND_bounding_box(mask)
            bb_min, bb_max = bb_min[1:], bb_max[1:]
            crop_min = [random.randint(bb_min[i], bb_max[i]) - int(self.output_size[i]/2) \
                for i in range(input_dim)]
            crop_min = [max(0, item) for item in crop_min]
            crop_min = [min(crop_min[i], input_shape[i+1] - self.output_size[i]) \
                for i in range(input_dim)]

        crop_max = [crop_min[i] + self.output_size[i] \
            for i in range(input_dim)]
        crop_min = [0] + crop_min
        crop_max = list(input_shape[0:1]) + crop_max
        sample['RandomCrop_Param'] = json.dumps((input_shape, crop_min, crop_max))

        image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max)
        sample['image'] = image_t

        if('label' in sample and sample['label'].shape[1:] == image.shape[1:]):
            label = sample['label']
            crop_max[0] = label.shape[0]
            label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max)
            sample['label'] = label
        if('weight' in sample and sample['weight'].shape[1:] == image.shape[1:]):
            weight = sample['weight']
            crop_max[0] = weight.shape[0]
            weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max)
            sample['weight'] = weight
        return sample

    def inverse_transform_for_prediction(self, sample):
        ''' Rescale sample['predict'] (5D or 4D) to the original spatial shape.
        Assume the batch size is 1, otherwise the crop parameters may be different
        for different elements in the batch.

        origin_shape is a 4D or 3D vector as saved in __call__().'''
        if(isinstance(sample['RandomCrop_Param'], (list, tuple))):
            params = json.loads(sample['RandomCrop_Param'][0])
        else:
            params = json.loads(sample['RandomCrop_Param'])
        origin_shape = params[0]
        crop_min = params[1]
        crop_max = params[2]
        predict = sample['predict']
        origin_shape = list(predict.shape[:2]) + origin_shape[1:]
        output_predict = np.zeros(origin_shape, predict.dtype)
        crop_min = [0, 0] + crop_min[1:]
        crop_max = list(predict.shape[:2]) + crop_max[1:]
        output_predict = set_ND_volume_roi_with_bounding_box_range(output_predict,
            crop_min, crop_max, predict)

        sample['predict'] = output_predict
        return sample
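
To illustrate the configuration convention used by these transforms (each key is the class name plus the field name, lowercased, matching the params['RandomCrop_output_size'.lower()] lookups above), here is a minimal usage sketch; the shapes and parameter values are assumptions for illustration, not part of this commit:

import numpy as np
from pymic.transform.crop import RandomCrop

params = {
    'randomcrop_output_size': (32, 64, 64),   # [D, H, W]
    'randomcrop_foreground_focus': False,
    'randomcrop_foreground_ratio': 0.5,
    'randomcrop_mask_label': None,
    'randomcrop_inverse': True,
}
transform = RandomCrop(params)

sample = {'image': np.random.rand(1, 48, 96, 96)}   # [C, D, H, W]
sample = transform(sample)
print(sample['image'].shape)        # (1, 32, 64, 64)
print(sample['RandomCrop_Param'])   # JSON-encoded (input_shape, crop_min, crop_max)

The same pattern applies to CropWithBoundingBox with the 'cropwithboundingbox_start', 'cropwithboundingbox_output_size' and 'cropwithboundingbox_inverse' keys.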

pymic/transform/flip.py

Lines changed: 68 additions & 0 deletions

# -*- coding: utf-8 -*-
from __future__ import print_function, division

import torch
import json
import math
import random
import numpy as np
from scipy import ndimage
from pymic.transform.abstract_transform import AbstractTransform
from pymic.util.image_process import *


class RandomFlip(AbstractTransform):
    """Randomly flip the image (shape [C, D, H, W] or [C, H, W])."""
    def __init__(self, params):
        """
        flip_depth (bool) : whether to randomly flip along the depth axis, only used for 3D images
        flip_height (bool): whether to randomly flip along the height axis
        flip_width (bool) : whether to randomly flip along the width axis
        """
        self.flip_depth = params['RandomFlip_flip_depth'.lower()]
        self.flip_height = params['RandomFlip_flip_height'.lower()]
        self.flip_width = params['RandomFlip_flip_width'.lower()]
        self.inverse = params['RandomFlip_inverse'.lower()]

    def __call__(self, sample):
        image = sample['image']
        input_shape = image.shape
        input_dim = len(input_shape) - 1
        flip_axis = []
        if(self.flip_width):
            if(random.random() > 0.5):
                flip_axis.append(-1)
        if(self.flip_height):
            if(random.random() > 0.5):
                flip_axis.append(-2)
        if(input_dim == 3 and self.flip_depth):
            if(random.random() > 0.5):
                flip_axis.append(-3)

        sample['RandomFlip_Param'] = json.dumps(flip_axis)
        if(len(flip_axis) > 0):
            # use .copy() to avoid negative strides of the numpy array,
            # as the current version of pytorch does not support negative strides
            image_t = np.flip(image, flip_axis).copy()
            sample['image'] = image_t
            if('label' in sample and sample['label'].shape[1:] == image.shape[1:]):
                sample['label'] = np.flip(sample['label'], flip_axis).copy()
            if('weight' in sample and sample['weight'].shape[1:] == image.shape[1:]):
                sample['weight'] = np.flip(sample['weight'], flip_axis).copy()

        return sample

    def inverse_transform_for_prediction(self, sample):
        ''' Flip sample['predict'] (5D or 4D) back to the original direction.
        Assume the batch size is 1, otherwise the flip parameters may be different
        for different elements in the batch.

        flip_axis is a list as saved in __call__().'''
        if(isinstance(sample['RandomFlip_Param'], (list, tuple))):
            flip_axis = json.loads(sample['RandomFlip_Param'][0])
        else:
            flip_axis = json.loads(sample['RandomFlip_Param'])
        if(len(flip_axis) > 0):
            sample['predict'] = np.flip(sample['predict'], flip_axis).copy()
        return sample
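
A minimal round-trip sketch (shapes assumed for illustration) showing how the flip axes saved in RandomFlip_Param let the prediction be flipped back to the original orientation; since each flip is its own inverse, replaying np.flip with the saved axes undoes the augmentation:

import numpy as np
from pymic.transform.flip import RandomFlip

params = {
    'randomflip_flip_depth': False,
    'randomflip_flip_height': True,
    'randomflip_flip_width': True,
    'randomflip_inverse': True,
}
transform = RandomFlip(params)

sample = {'image': np.random.rand(1, 64, 64)}   # [C, H, W]
sample = transform(sample)                      # may flip along H and/or W

# fake a network output with a batch axis; [N, C, H, W] with N = 1
sample['predict'] = sample['image'][np.newaxis]
sample = transform.inverse_transform_for_prediction(sample)

Because the flip axes are stored as negative indices (-1, -2, -3), the same saved values apply whether or not the prediction carries a batch axis.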
Lines changed: 39 additions & 0 deletions

# -*- coding: utf-8 -*-
from __future__ import print_function, division

import torch
import json
import math
import random
import numpy as np
from scipy import ndimage
from pymic.transform.abstract_transform import AbstractTransform
from pymic.util.image_process import *


class ChannelWiseGammaCorrection(AbstractTransform):
    """
    Apply random gamma correction to each channel.
    """
    def __init__(self, params):
        """
        (gamma_min, gamma_max) specify the range of gamma
        """
        self.gamma_min = params['ChannelWiseGammaCorrection_gamma_min'.lower()]
        self.gamma_max = params['ChannelWiseGammaCorrection_gamma_max'.lower()]
        self.inverse = params['ChannelWiseGammaCorrection_inverse'.lower()]

    def __call__(self, sample):
        image = sample['image']
        for chn in range(image.shape[0]):
            gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min
            img_c = image[chn]
            v_min = img_c.min()
            v_max = img_c.max()
            if(v_max > v_min):  # skip constant channels to avoid division by zero
                img_c = (img_c - v_min) / (v_max - v_min)
                img_c = np.power(img_c, gamma_c) * (v_max - v_min) + v_min
                image[chn] = img_c

        sample['image'] = image
        return sample
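
The correction normalizes each channel to [0, 1], raises it to a random power gamma drawn from [gamma_min, gamma_max], and rescales back to the original intensity range, so each channel's min and max are preserved. A minimal usage sketch; the module path in the import is an assumption, as the file name is not shown in this diff:

import numpy as np
from pymic.transform.intensity import ChannelWiseGammaCorrection  # module path assumed

params = {
    'channelwisegammacorrection_gamma_min': 0.7,
    'channelwisegammacorrection_gamma_max': 1.5,
    'channelwisegammacorrection_inverse': False,
}
transform = ChannelWiseGammaCorrection(params)

sample = {'image': np.random.rand(2, 64, 64)}   # two channels
sample = transform(sample)
# gamma < 1 brightens mid-range intensities, gamma > 1 darkens them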
