Skip to content

Commit e2ff655

Browse files
author
GOESTERN-0968798
committed
Initial script for affine transformation of training data
1 parent fd52df3 commit e2ff655

File tree

1 file changed

+295
-0
lines changed

1 file changed

+295
-0
lines changed
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
#!/usr/bin/python
2+
# -- coding: utf-8 --
3+
4+
import os, sys
5+
import argparse
6+
import numpy as np
7+
import multiprocessing
8+
import logging
9+
10+
import matplotlib.path as mpltPath
11+
import scipy.ndimage
12+
from scipy.spatial.transform import Rotation as R
13+
from scipy.spatial import ConvexHull
14+
15+
from tifffile import tifffile
16+
17+
# https://github.com/mrirecon/dl-segmentation-realtime-cmr/blob/main/scripts/assess_dl_seg_utils.py
18+
def mask_from_polygon_mplt(coordinates:list, img_shape:list):
19+
"""
20+
Transfer a polygon into a binary mask by using matplotlib.Path
21+
22+
:param list coordinates: List of coordinates in format [[x1,y1], [x2,y2], ...]
23+
:param list img_shape: Shape of the output mask in format [xdim,ydim]
24+
:returns: Binary 2D mask
25+
:rtype: np.array
26+
"""
27+
path = mpltPath.Path(coordinates)
28+
points = [[x+0.5,y+0.5] for y in range(img_shape[1]) for x in range(img_shape[0])]
29+
inside2 = path.contains_points(points)
30+
new_mask = np.zeros(img_shape, dtype=np.int32)
31+
count=0
32+
for y in range(img_shape[1]):
33+
for x in range(img_shape[0]):
34+
new_mask[x,y] = int(inside2[count])
35+
count += 1
36+
return new_mask
37+
38+
def transform_polygon_mp(args:tuple):
39+
"""
40+
Multiprocessing the transformation of a polygon to a binary mask.
41+
42+
:param tuple args: tuple(coordinates, img_shape)
43+
WHERE
44+
list coordinates is list of coordinates in format [[x1,y1], [x2,y2], ...]
45+
list img_shape is list of dimensions of output mask [xdim, ydim]
46+
:returns: binary mask
47+
:rtype: np.array
48+
"""
49+
(coordinates, label_value, img_shape) = args
50+
mask = mask_from_polygon_mplt(coordinates, img_shape)
51+
mask[mask > 0] = label_value
52+
return mask
53+
54+
def combined_mask_from_list(mask_list, label_values):
55+
"""
56+
Create a combined mask for an input list of binary masks and label values.
57+
The binary masks are assigned the corresponding label value.
58+
Overlap is removed, higher label values take precedence.
59+
"""
60+
new_mask = np.zeros((mask_list[0].shape[0], mask_list[0].shape[1]), dtype=np.int32)
61+
for (mask, label) in zip(mask_list, label_values):
62+
new_mask += mask
63+
new_mask[new_mask > label] = label
64+
return new_mask
65+
66+
def make_labels_convex_mp(label_arr, mp_number = 8):
67+
"""
68+
Multi-processing for creating convex labels.
69+
70+
:param np.ndarray label_arr: Array containing input labels with format [x, y, z]
71+
:param int mp_number: Number of multi-processes
72+
:return: convex labels
73+
:rtype: np.ndarray
74+
"""
75+
label_list = [label_arr[:,:,i] for i in range(label_arr.shape[2])]
76+
pool = multiprocessing.Pool(processes=mp_number)
77+
mask_list = pool.map(make_labels_convex, label_list)
78+
mask_stack = np.stack(mask_list, axis=2)
79+
return mask_stack
80+
81+
def make_labels_convex(arr, min_pixels_register = 10, min_pixels_new = 250):
82+
"""
83+
Create convex labels.
84+
Input labels with less than "min_pixels_register" are ignored and removed.
85+
Convex labels with less than "min_pixels_new" are removed.
86+
87+
:param np.ndarray arr: 2D array in format [x, y]
88+
:param int min_pixels_register: Minimum size for input labels.
89+
:param int min_pixels_new: Minimum size for output labels.
90+
"""
91+
label_values = []
92+
mask_list = []
93+
if 0 != np.max(arr):
94+
for label_idx in range(1, np.max(arr)+1):
95+
if np.count_nonzero(arr == label_idx) > min_pixels_register:
96+
mask = arr.copy()
97+
mask[mask != label_idx] = 0
98+
coordinates = []
99+
100+
for x in range(mask.shape[0]):
101+
for y in range(mask.shape[1]):
102+
if mask[x,y] != 0:
103+
coordinates.append([x,y])
104+
105+
hull = ConvexHull(coordinates)
106+
edge_coords = []
107+
108+
for idx in hull.vertices:
109+
edge_coords.append(coordinates[idx])
110+
111+
new_mask = mask_from_polygon_mplt(edge_coords, arr.shape)
112+
new_mask[new_mask > 0] = label_idx
113+
114+
if np.count_nonzero(arr == label_idx) > min_pixels_new:
115+
mask_list.append(new_mask)
116+
label_values.append(label_idx)
117+
118+
# remove small labels
119+
else:
120+
arr[arr == label_idx] = 0
121+
122+
if len(mask_list) > 0:
123+
return combined_mask_from_list(mask_list, label_values)
124+
else:
125+
return arr
126+
127+
else:
128+
return arr
129+
130+
131+
def read_tif_stack(file):
132+
"""
133+
Read stack of TIF files.
134+
"""
135+
images = tifffile.imread(file)
136+
images = np.transpose(images, (1,2,0))
137+
return images
138+
139+
def affine_transform_euler(data, euler_angles, label_flag = False):
140+
"""
141+
Affine transform using Euler angles as input.
142+
"""
143+
# Euler angles [degree]
144+
# https://quaternions.online/ for visualization
145+
(ex, ey, ez) = euler_angles
146+
147+
rot_obj = R.from_euler('xyz', [ex, ey, ez], degrees=True)
148+
149+
# calculate offset to have center of input at center of output
150+
(xdim, ydim, zdim) = data.shape
151+
x_vec = np.array([[xdim // 2, ydim // 2, zdim // 2]])
152+
rot_matrix = rot_obj.as_matrix()
153+
y_vec = np.dot(rot_matrix, x_vec.T)
154+
t_vec = x_vec.T - y_vec
155+
offset = [t_vec[0][0], t_vec[1][0], t_vec[2][0]]
156+
157+
if label_flag:
158+
result = scipy.ndimage.affine_transform(data, rot_matrix, order=0, offset=offset, prefilter=False)
159+
else:
160+
result = scipy.ndimage.affine_transform(data, rot_matrix, offset=offset)
161+
return result
162+
163+
def pad_scaled_output(arr, target_shape, pad_type = 'zero'):
164+
"""
165+
Pad input array, either with constant value 'zero'
166+
or the mean value of corner sections of the volume with the smallest standard deviation.
167+
168+
:param np.ndarray arr: Input array in format [x, y, z]
169+
:param tuple target_shape: Shape of the padded volume in format (x, y, z)
170+
:param str pad_type: Either 'zero' or 'mean'
171+
:returns: Padded input
172+
:rtype: np.ndarray
173+
"""
174+
175+
if "mean" == pad_type:
176+
corner_arrays = [arr[0:arr.shape[0]//10, 0:arr.shape[1]//10, :], arr[0:arr.shape[0]//10, -arr.shape[1]//10:], arr[-arr.shape[0]//10:, 0:arr.shape[1]//10], arr[-arr.shape[0]//10:, -arr.shape[1]//10:]]
177+
stdv = [np.std(a) for a in corner_arrays]
178+
min_std_index = stdv.index(min(stdv))
179+
pad_value = np.mean(corner_arrays[min_std_index])
180+
elif "zero" == pad_type:
181+
pad_value = 0
182+
else:
183+
sys.exit("Choose either 'zero' or 'mean' for padding.")
184+
185+
logging.info("Using padding with pad_value " + str(pad_value))
186+
187+
pad_before_x = (target_shape[0] - arr.shape[0]) // 2
188+
pad_after_x = target_shape[0] - pad_before_x - arr.shape[0]
189+
190+
pad_before_y = (target_shape[1] - arr.shape[1]) // 2
191+
pad_after_y = target_shape[1] - pad_before_y - arr.shape[1]
192+
193+
return np.pad(arr, ((pad_before_x, pad_after_x), (pad_before_y, pad_after_y), (0,0)), constant_values = pad_value)
194+
195+
def scale_by_factor(array, scale, label_flag = False):
196+
"""
197+
Scaling an array by a given factor.
198+
"""
199+
# scaling by factor 1 / s
200+
s = 1 / scale
201+
matrix = np.asarray([
202+
[s, 0, 0, 0],
203+
[0, s, 0, 0],
204+
[0, 0, 1, 0],
205+
[0, 0, 0, 1],
206+
])
207+
output_shape = (int(array.shape[0] / s), int(array.shape[1] / s), array.shape[2])
208+
209+
if label_flag:
210+
scaled = np.ndarray(output_shape, dtype=np.int32)
211+
result = scipy.ndimage.affine_transform(array, matrix, order=0, output=scaled, output_shape=output_shape, prefilter=False)
212+
else:
213+
scaled = np.ndarray(output_shape, dtype=np.uint16)
214+
result = scipy.ndimage.affine_transform(array, matrix, output=scaled, output_shape=output_shape)
215+
216+
return result
217+
218+
def main(input_file, dir_out, scale, ex, ey, ez, make_convex):
219+
# check file format
220+
if not os.path.isfile(input_file):
221+
sys.exit("Input file does not exist.")
222+
223+
if input_file.split(".")[-1] not in ["TIFF", "TIF", "tiff", "tif"]:
224+
sys.exit("Input file must be in tif format.")
225+
226+
basename = input_file.split("/")[-1].split(".tif")[0]
227+
228+
if scale != 1 and not (ex == 0 and ey == 0 and ez == 0):
229+
sys.exit("Either scaling or rotation. A combination has not been implemented yet.")
230+
231+
# check for corresponding annotations
232+
data_dir = input_file.split(basename)[0]
233+
label_path = data_dir + basename + "_annotations.tif"
234+
if not os.path.isfile(label_path):
235+
logging.debug("No corresponding label was found.")
236+
label_path = ""
237+
238+
if "" == dir_out:
239+
logging.debug("The output is stored in the directory containing the input, since no output directory has been given.")
240+
dir_out = data_dir
241+
242+
image_file = os.path.join(data_dir, basename + ".tif")
243+
images = read_tif_stack(image_file)
244+
245+
#---Images---
246+
if scale != 1:
247+
images_aff = scale_by_factor(images, scale)
248+
images_aff = pad_scaled_output(images_aff, images.shape, pad_type = "zero")
249+
save_images = os.path.join(dir_out, basename + "_aff_scaled_" + str(scale) + ".tif")
250+
251+
else:
252+
images_aff = affine_transform_euler(images, (ex, ey, ez))
253+
save_images = os.path.join(dir_out, basename + "_affExyz" + str(int(ex)).zfill(2) + str(int(ey)).zfill(2) + str(int(ez)).zfill(2) + ".tif")
254+
255+
array_out = np.transpose(images_aff, (2,0,1))
256+
tifffile.imwrite(save_images, array_out)
257+
258+
#---Labels---
259+
if label_path != "":
260+
label_file = os.path.join(data_dir, basename + "_annotations.tif")
261+
labels = read_tif_stack(label_file)
262+
263+
if scale != 1:
264+
labels_aff = scale_by_factor(labels, scale)
265+
labels_aff = pad_scaled_output(labels_aff, labels.shape, pad_type = "zero")
266+
save_labels = os.path.join(dir_out, basename + "_aff_scaled_" + str(scale) + "_annotations.tif")
267+
268+
else:
269+
labels_aff = affine_transform_euler(labels, (ex, ey, ez), label_flag = True)
270+
if make_convex:
271+
labels_aff = make_labels_convex_mp(labels_aff, mp_number = 16)
272+
273+
save_labels = os.path.join(dir_out, basename + "_affExyz" + str(int(ex)).zfill(2) + str(int(ey)).zfill(2) + str(int(ez)).zfill(2) + "_annotations.tif")
274+
275+
array_out = np.transpose(labels_aff, (2,0,1))
276+
tifffile.imwrite(save_labels, array_out)
277+
278+
if __name__ == "__main__":
279+
280+
parser = argparse.ArgumentParser(
281+
description="Script to augment LSM data in tif format using rotation or scaling.")
282+
283+
parser.add_argument('input', type=str, help="Input image file")
284+
285+
parser.add_argument('-o', "--output", type=str, default="", help="Output directory")
286+
parser.add_argument('-c', "--convex", action='store_true', help="Flag for making affine transformed output labels convex.")
287+
288+
parser.add_argument('-s', "--scale", type=float, default=1, help="Factor to scale input with affine transformation. Only supports s<=1.")
289+
parser.add_argument('--ex', type=float, default=0, help="Euler angle x")
290+
parser.add_argument('--ey', type=float, default=0, help="Euler angle y")
291+
parser.add_argument('--ez', type=float, default=0, help="Euler angle z")
292+
293+
args = parser.parse_args()
294+
295+
main(args.input, args.output, args.scale, args.ex, args.ey, args.ez, args.convex)

0 commit comments

Comments
 (0)