-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataset.py
More file actions
340 lines (282 loc) · 13.2 KB
/
dataset.py
File metadata and controls
340 lines (282 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import logging
import time
import h5py
import zarr
import tracemalloc
from tqdm import tqdm
import numpy as np
from collections import deque
import zarr
from torch.utils.data import IterableDataset
def get_labeled_position(label, class_value, label_any=None):
    """Sample a random voxel index inside the specified class.

    Using a pre-computed ``np.any(label == class_value, axis=2)``
    along the third axis makes repeated sampling more efficient.
    If there is no valid position, None is returned.

    Args:
        label (np.array): array with label information H,W,D
        class_value (int): value of specified class
        label_any (np.array): pre-computed np.any(label == class_value, axis=2)

    Returns:
        list: [h, w, d] indices of a random valid position inside the
            given label, or None if the class is absent.
    """
    if label_any is None:
        label_any = np.any(label == class_value, axis=2)
    # 2-d (H,W) positions that contain the class somewhere along depth
    valid_idx = np.argwhere(label_any)
    if not valid_idx.size:
        return None
    # choose a random valid 2-d position
    idx = valid_idx[np.random.randint(0, valid_idx.shape[0])]
    # Sample an additional index along the third axis (=2), uniformly
    # among ALL depths whose voxel equals the class value.
    # (Fix: the previous code took argwhere(...)[0], so only the first
    # matching depth could ever be drawn.)
    depth_candidates = np.argwhere(label[idx[0], idx[1], :] == class_value)[:, 0]
    depth = np.random.choice(depth_candidates)
    return [idx[0], idx[1], depth]
def get_random_patch_indices(patch_size, img_shape, pos=None):
    """Create random patch corner indices.

    Creates (valid) max./min. corner indices of a patch.
    If a specific position is given, the patch must contain
    this index position. If position is None, a random
    patch will be produced.

    Args:
        patch_size (np.array): patch dimensions (H,W,D)
        img_shape (np.array): shape of the image (H,W,D)
        pos (np.array, optional): position (H,W,D) which should be
            included in the sampled patch. Defaults to None.

    Returns:
        (np.array, np.array): patch corner indices (e.g. first axis
                              index_ini[0]:index_fin[0])
    """
    # 3d - image array should have shape H,W,D
    # if pos is given, the patch has to surround this position.
    # Fix: `if pos:` raised ValueError for np.array inputs and was
    # ambiguous for [0,0,0]; test against None explicitly.
    if pos is not None:
        # Fix: np.int was removed in NumPy 1.24 -> use builtin int
        pos = np.array(pos, dtype=int)
        min_index = np.maximum(pos - patch_size + 1, 0)
        max_index = np.minimum(img_shape - patch_size + 1, pos + 1)
    else:
        min_index = np.array([0, 0, 0])
        max_index = img_shape - patch_size + 1
    # create valid patch boundaries
    index_ini = np.random.randint(low=min_index, high=max_index)
    index_fin = index_ini + patch_size
    return index_ini, index_fin
def one_hot_to_label(data,
                     add_background=True):
    """Collapse a one-hot encoded array into a single class-value channel.

    Args:
        data (np.array): one-hot encoded array C,H,W,D.
        add_background (bool, optional): prepend a background channel so
            voxels with no active class map to 0 (foreground classes then
            start at 1). Defaults to True.

    Returns:
        np.array: 1,H,W,D array of integer class values
    """
    if add_background:
        # background = voxels where no foreground channel is set
        foreground_mask = np.any(data, axis=0, keepdims=True)
        data = np.concatenate([~foreground_mask, data], axis=0)
    class_map = data.argmax(axis=0)
    return class_map[np.newaxis, ...]
class DataReader:
    """Abstract base class for dataset readers.

    Subclasses implement read()/get_data_shape()/get_data_attribute()/close()
    for a concrete storage backend (hdf5/zarr/nifti directory).
    """

    def read(self, subject_keys, group, dtype=np.float16, preload=True):
        """Yield one array per subject key; override in subclasses.

        Fix: signature now matches the positional call in
        read_data_to_memory() and the DataReaderHDF5 override
        (previously `(group_key, subj_keys, dtype=True, ...)`).
        """
        pass

    def read_data_to_memory(self, subject_keys, group, dtype=np.float16, preload=True):
        """Reads data from source to memory.

        The dataset should be stored using the following structure:
        <data_path>/<group>/<key>...
        A generator function (read) can be defined to read data respecting this
        structure (implementations for hdf5/zarr/nifti directory are available).

        Args:
            subject_keys (list): identifying keys
            group (str): data group name
            dtype (type, optional): store dtype (default np.float16/np.uint8). Defaults to np.float16.
            preload (bool, optional): if False, data will be loaded on the fly. Defaults to True.

        Returns:
            object: collections.deque list containing the dataset
        """
        logger = logging.getLogger(__name__)
        logger.info(f'loading group [{group}]...')
        # check timing and memory allocation
        t = time.perf_counter()
        tracemalloc.start()
        data = deque(self.read(subject_keys, group, dtype, preload))
        current, peak = tracemalloc.get_traced_memory()
        logger.debug(f'finished: {time.perf_counter() - t :.3f} s, current memory usage {current / 10**9: .2f}GB, peak memory usage {peak / 10**9:.2f}GB')
        return data

    def get_data_shape(self, subject_keys, group):
        """Return {key: shape} for the given group; override in subclasses."""
        pass

    def get_data_attribute(self, subject_keys, group, attribute):
        """Return {key: attribute value} for the given group; override in subclasses."""
        pass

    def close(self):
        """Release backend resources; override in subclasses."""
        pass
class DataReaderHDF5(DataReader):
    """HDF5-backed DataReader (dataset layout: <group>/<subject_key>)."""

    def __init__(self, path_data):
        self.path_data = path_data
        self.hf = h5py.File(str(path_data), 'r')
        self.logger = logging.getLogger(__name__)

    def read(self, subject_keys, group, dtype=np.float16, preload=True):
        """Yield the dataset of every subject key.

        With preload=True the data is copied into memory and converted
        to `dtype`; otherwise the lazy h5py dataset object is yielded.
        """
        for key in tqdm(subject_keys):
            dataset = self.hf[f'{group}/{key}']
            yield dataset[:].astype(dtype) if preload else dataset

    def get_data_shape(self, subject_keys, group):
        """Map each subject key to the shape of its dataset."""
        return {key: self.hf[f'{group}/{key}'].shape for key in subject_keys}

    def get_data_attribute(self, subject_keys, group, attribute):
        """Map each subject key to one stored dataset attribute."""
        return {key: self.hf[f'{group}/{key}'].attrs[attribute] for key in subject_keys}

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.hf.close()
def grid_patch_generator(img, patch_size, patch_overlap, **kwargs):
    """Generate a grid of overlapping patches covering the whole image.

    All patches overlap their neighbors by 2*patch_overlap per axis.
    Cropping each patch by patch_overlap lets the patches be
    re-assembled to the original image shape.
    Additional np.pad arguments can be passed via **kwargs.

    Args:
        img (np.array): CxHxWxD
        patch_size (list/np.array): patch shape [H,W,D]
        patch_overlap (list/np.array): overlap (per axis) [H,W,D]

    Yields:
        np.array, np.array, int: patch data CxHxWxD,
                                 patch position [H,W,D],
                                 patch number
    """
    dim = 3
    patch_size = np.asarray(patch_size)
    patch_overlap = np.asarray(patch_overlap)
    img_size = np.asarray(img.shape[1:])
    # stride between neighboring patch origins
    cropped_patch_size = patch_size - 2 * patch_overlap
    n_patches = np.ceil(img_size / cropped_patch_size).astype(int)
    # Padding needed so the patch grid covers the image exactly.
    # Fix: the previous `cropped - img % cropped` padded a full extra
    # stride when img_size was evenly divisible; (-img) % cropped is 0
    # in that case (yielded patches are unchanged either way).
    overhead = (-img_size) % cropped_patch_size
    pad_width = [[0, 0]] + [[patch_overlap[k], patch_overlap[k] + overhead[k]]
                            for k in range(dim)]
    padded_img = np.pad(img, pad_width, **kwargs)
    pos = [np.arange(0, n_patches[k]) * cropped_patch_size[k] for k in range(dim)]
    count = 0
    for p0 in pos[0]:
        for p1 in pos[1]:
            for p2 in pos[2]:
                idx = np.array([p0, p1, p2])
                idx_end = idx + patch_size
                patch = padded_img[:, idx[0]:idx_end[0],
                                   idx[1]:idx_end[1],
                                   idx[2]:idx_end[2]]
                yield patch, idx, count
                count += 1
class GridPatchSampler(IterableDataset):
    def __init__(self,
                 data_path,
                 subject_keys,
                 patch_size, patch_overlap,
                 out_channels=1,
                 out_dtype=np.uint8,
                 channel_selection=None,
                 image_group='images',
                 ReaderClass=DataReaderHDF5,
                 pad_args={'mode': 'symmetric'}):
        """GridPatchSampler for patch based inference.

        Creates IterableDataset of overlapping patches (overlap between neighboring
        patches: 2*patch_overlap).
        To assemble the original image shape use add_processed_batch(). The
        number of channels for the assembled images (corresponding to the
        channels of the processed patches) has to be defined by out_channels:
        <out_channels>xHxWxD.

        Args:
            data_path (Path/str): data path (e.g. zarr/hdf5 file)
            subject_keys (list): subject keys
            patch_size (list/np.array): [H,W,D] patch shape
            patch_overlap (list/np.array): [H,W,D] patch boundary
            out_channels (int, optional): number of channels for the processed patches. Defaults to 1.
            out_dtype (dtype, optional): data type of processed patches. Defaults to np.uint8.
            channel_selection (list/slice, optional): use only specified channels. Defaults to None (all channels).
            image_group (str, optional): image group tag. Defaults to 'images'.
            ReaderClass (class, optional): data reader class. Defaults to DataReaderHDF5.
            pad_args (dict, optional): additional np.pad parameters. Defaults to {'mode': 'symmetric'}.
        """
        self.data_path = str(data_path)
        self.subject_keys = subject_keys
        self.patch_size = np.array(patch_size)
        self.patch_overlap = patch_overlap
        self.image_group = image_group
        self.ReaderClass = ReaderClass
        self.out_channels = out_channels
        self.channel_selection = channel_selection
        self.out_dtype = out_dtype
        # in-memory zarr group that collects the assembled result images
        self.results = zarr.group()
        self.originals = {}
        self.pad_args = pad_args
        # read image data (shape/affine/content) for each subject in subject_keys
        reader = self.ReaderClass(self.data_path)
        self.data_shape = reader.get_data_shape(self.subject_keys, self.image_group)
        self.data_affine = reader.get_data_attribute(self.subject_keys, self.image_group, "affine")
        self.data_generator = reader.read_data_to_memory(self.subject_keys, self.image_group, dtype=np.float16)
        reader.close()

    def add_processed_batch(self, sample):
        """Assembles the processed patches to the original array shape.

        Args:
            sample (dict): 'subject_key', 'pos', 'data' (B,C,H,W,D batch)
        """
        for i, key in enumerate(sample['subject_key']):
            # Crop the overlap border per spatial axis.
            # Fix: axis 0 previously used patch_overlap[1] as its stop
            # index, which mis-cropped anisotropic overlaps.
            # NOTE(review): assumes patch_overlap >= 1 per axis; an
            # overlap of 0 would produce an empty `0:-0` slice.
            cropped_patch = np.array(sample['data'][i, :,
                                     self.patch_overlap[0]:-self.patch_overlap[0],
                                     self.patch_overlap[1]:-self.patch_overlap[1],
                                     self.patch_overlap[2]:-self.patch_overlap[2]])
            # start and end position inside the assembled image
            pos = np.array(sample['pos'][i])
            pos_end = np.array(pos + np.array(cropped_patch.shape[1:]))
            # check if end position is outside the original array (due to padding)
            # -> crop again (overhead)
            img_size = np.array(self.data_shape[key][1:])
            crop_pos_end = np.minimum(pos_end, img_size)
            overhead = np.maximum(pos_end - crop_pos_end, [0, 0, 0])
            new_patch_size = np.array(cropped_patch.shape[1:]) - overhead
            # add the patch to the corresponding entry in the result container
            ds_shape = np.array(self.data_shape[key])
            ds_shape[0] = self.out_channels
            ds = self.results.require_dataset(key, shape=ds_shape, dtype=self.out_dtype, chunks=False)
            ds.attrs["affine"] = np.array(self.data_affine[key]).tolist()
            ds[:, pos[0]:pos_end[0],
               pos[1]:pos_end[1],
               pos[2]:pos_end[2]] = cropped_patch[:, :new_patch_size[0],
                                                  :new_patch_size[1],
                                                  :new_patch_size[2]].astype(self.out_dtype)

    def get_assembled_data(self):
        """Gets the dictionary with assembled/processed images.

        Returns:
            dict: Dictionary containing the processed and assembled images (key=subject_key)
        """
        return self.results

    def grid_patch_sampler(self):
        """Data reading and patch generation.

        Yields:
            dict: patch dictionary (subject_key, position, count and data)
        """
        # create a patch iterator per subject
        for subj_idx, sample in enumerate(tqdm(self.data_generator)):
            subject_key = self.subject_keys[subj_idx]
            patch_generator = grid_patch_generator(
                sample, self.patch_size, self.patch_overlap, **self.pad_args)
            for patch, idx, count in patch_generator:
                # Fix: only index channels when a selection was requested;
                # indexing with the default None inserted a spurious new
                # axis (None == np.newaxis) instead of keeping all channels.
                if self.channel_selection is not None:
                    patch = patch[self.channel_selection, :, :, :]
                patch_dict = {'data': patch,
                              'subject_key': subject_key,
                              'pos': idx,
                              'count': count}
                yield patch_dict

    def __iter__(self):
        return iter(self.grid_patch_sampler())

    def __len__(self):
        # NOTE(review): a fixed dummy length, presumably to satisfy a
        # DataLoader that requires __len__ — the true patch count is not
        # known until iteration; confirm with the consuming code.
        return 1