-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpreprocessing.py
More file actions
123 lines (95 loc) · 4.44 KB
/
preprocessing.py
File metadata and controls
123 lines (95 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys
import time
import zarr
import numpy as np
from joblib import Parallel, delayed
from pathlib import Path
import nibabel as nib
import nilearn
from nilearn.image import image
from Utils.utils import reorient_nii, zarr_store_data, zarr_convert2hdf5
def preprocess_image(img, target_spacing,
                     target_shape=None,
                     target_orientation=('L', 'A', 'S'),
                     interpolation='continuous',
                     fill_value=0,
                     dtype=np.float16):
    """Reorient, resample and intensity-normalize a NIfTI image.

    The image is first reoriented to ``target_orientation``, then resampled
    to ``target_spacing`` (and optionally ``target_shape``), and finally
    rescaled so the [5th, 95th] intensity percentiles map to [0, 1], with
    values clipped to that range.

    Args:
        img (nib.Nifti1Image): Original NIfTI image.
        target_spacing (tuple): Target voxel spacing.
        target_shape (tuple, optional): Target image shape. Defaults to None.
        target_orientation (tuple, optional): Target axis codes.
            Defaults to ('L', 'A', 'S').
        interpolation (str, optional): Interpolation type
            (nilearn.image.resample_img). Defaults to 'continuous'.
        fill_value (int, optional): Fill value for voxels outside the
            original field of view. Defaults to 0.
        dtype (type, optional): Data type for the resampled image.
            Defaults to np.float16.

    Returns:
        nib.Nifti1Image: Processed NIfTI image with intensities in [0, 1].
    """
    # Reorient to target orientation.
    img = reorient_nii(img, target_orientation)

    # Resample to target resolution: rescale the rotation/scaling part of
    # the affine so voxels map to the requested physical spacing.
    orig_spacing = np.array(img.header.get_zooms())
    target_affine = np.copy(img.affine)
    target_affine[:3, :3] = np.diag(np.asarray(target_spacing) / orig_spacing) @ img.affine[:3, :3]
    resampled_img = nilearn.image.resample_img(img,
                                               target_affine=target_affine,
                                               target_shape=target_shape,
                                               interpolation=interpolation,
                                               fill_value=fill_value)
    resampled_img.set_data_dtype(dtype)

    # Percentile normalization to [0, 1].
    img_data = resampled_img.get_fdata().astype(dtype)
    p_low, p_high = np.percentile(img_data, [5, 95])
    scale = p_high - p_low
    if scale > 0:
        img_data = (img_data - p_low) / scale
    else:
        # Degenerate case (p5 == p95, e.g. a constant image): the original
        # code divided by zero here and produced NaN/inf; map to zeros.
        img_data = np.zeros_like(img_data)
    img_data = np.clip(img_data, 0.0, 1.0)
    return nib.Nifti1Image(img_data.astype(dtype), resampled_img.affine)
def preprocessing(config):
    """Preprocess every raw NIfTI image found under the configured raw-data
    directory, store the results in a zarr directory store, then convert the
    store to HDF5.

    Args:
        config: Configuration object whose ``preprocessing`` namespace
            provides ``target_spacing``, ``target_shape``,
            ``target_orientation``, ``num_workers`` and the three relative
            paths used below.
    """
    # ------ Preprocessing ------
    print('Preprocessing...')
    target_spacing = config.preprocessing.target_spacing
    target_shape = config.preprocessing.target_shape
    target_orientation = tuple(config.preprocessing.target_orientation)
    num_workers = config.preprocessing.num_workers
    base_dir = os.path.dirname(__file__)
    path_to_raw_data = os.path.join(base_dir, config.preprocessing.path_to_raw_data)
    path_preprocessed_zarr = os.path.join(base_dir, config.preprocessing.path_preprocessed_zarr)
    path_preprocessed_hdf5 = os.path.join(base_dir, config.preprocessing.path_preprocessed_hdf5)

    # Build the file dictionary: subject key (first 6 chars of each
    # subdirectory name) -> path to the first file inside that subdirectory.
    # NOTE(review): assumes every entry of the raw-data dir is a directory
    # containing at least one file — confirm against the dataset layout.
    file_dict = {'images': dict()}
    for entry in Path(path_to_raw_data).iterdir():
        key = entry.name[:6]
        file_dict['images'][key] = os.path.join(str(entry), os.listdir(entry)[0])
    print(f'Found {len(file_dict["images"])} files to process ')
    keys = list(file_dict['images'])

    # Store results to a zarr directory store. (The original code also read
    # config.preprocessing.image_group here but immediately shadowed it with
    # the zarr group below, so that config value was never used.)
    store = zarr.storage.DirectoryStore(path_preprocessed_zarr)
    root = zarr.group(store=store, overwrite=True)
    image_group = root.require_group('images')

    # Process images
    print('Processing images...')

    def proc(key):
        """Load, preprocess and store a single image; errors are reported
        per-key so one failure does not abort the whole batch."""
        try:
            img_path = file_dict['images'][key]
            img = nib.Nifti1Image.load(img_path)
            # Preprocess at float32 precision; stored at float16 below.
            preprocessed_img = preprocess_image(img,
                                                target_spacing=target_spacing,
                                                target_shape=target_shape,
                                                target_orientation=target_orientation,
                                                dtype=np.float32)
            zarr_store_data(image_group, key,
                            preprocessed_img.get_fdata().astype(np.float16),
                            preprocessed_img.affine)
            print(f'{key} : {time.perf_counter()-t:.3f}s')
        except Exception as e:
            # The original bare `except:` also caught KeyboardInterrupt and
            # printed only the exception type; report the actual error.
            print(f'Error {key} \n {e!r}')

    t = time.perf_counter()
    Parallel(n_jobs=num_workers)(delayed(proc)(key) for key in keys)

    # Convert zarr to hdf5
    zarr_convert2hdf5(path_preprocessed_zarr, path_preprocessed_hdf5)