Skip to content

Commit db6fdf9

Browse files
committed
update denoise, improve estimation
1 parent 916b305 commit db6fdf9

File tree

6 files changed

+57
-24
lines changed

6 files changed

+57
-24
lines changed

nipype/interfaces/dipy/preprocess.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class DenoiseInputSpec(TraitedSpec):
9898
noise_model = traits.Enum('rician', 'gaussian', mandatory=True,
9999
usedefault=True,
100100
desc=('noise distribution model'))
101+
signal_mask = File(desc=('mask in which the mean signal '
102+
'will be computed'), exists=True)
101103
noise_mask = File(desc=('mask in which the standard deviation of noise '
102104
'will be computed'), exists=True)
103105
patch_radius = traits.Int(1, desc='patch radius')
@@ -147,16 +149,19 @@ def _run_interface(self, runtime):
147149
if isdefined(self.inputs.block_radius):
148150
settings['block_radius'] = self.inputs.block_radius
149151

152+
signal_mask = None
153+
if isdefined(self.inputs.signal_mask):
154+
signal_mask = nb.load(self.inputs.signal_mask).get_data()
150155
noise_mask = None
151-
if isdefined(self.inputs.in_mask):
156+
if isdefined(self.inputs.noise_mask):
152157
noise_mask = nb.load(self.inputs.noise_mask).get_data()
153158

154-
_, s = nlmeans_proxy(self.inputs.in_file,
155-
settings,
156-
noise_mask=noise_mask,
159+
_, s = nlmeans_proxy(self.inputs.in_file, settings,
160+
smask=signal_mask,
161+
nmask=noise_mask,
157162
out_file=out_file)
158163
iflogger.info(('Denoised image saved as {i}, estimated '
159-
'sigma={s}').format(i=out_file, s=s))
164+
'SNR={s}').format(i=out_file, s=str(s)))
160165
return runtime
161166

162167
def _list_outputs(self):
@@ -209,7 +214,9 @@ def resample_proxy(in_file, order=3, new_zooms=None, out_file=None):
209214

210215

211216
def nlmeans_proxy(in_file, settings,
212-
noise_mask=None, out_file=None):
217+
smask=None,
218+
nmask=None,
219+
out_file=None):
213220
"""
214221
Uses non-local means to denoise 4D datasets
215222
"""
@@ -228,17 +235,35 @@ def nlmeans_proxy(in_file, settings,
228235
data = img.get_data()
229236
aff = img.get_affine()
230237

231-
nmask = data[..., 0] > 80
232-
if noise_mask is not None:
233-
noise_mask = np.squeeze(noise_mask)
234-
nmask = np.zeros_like(noise_mask)
235-
nmask[noise_mask > 0] = 1
236-
if nmask.ndim != data.ndim:
237-
nmask = np.array([nmask] * data.shape[-1])
238-
239-
sigma = np.std(data[nmask > 0])
240-
den = nlmeans(data, sigma, **settings)
241-
238+
if data.ndims < 4:
239+
data = data[..., np.newaxis]
240+
b0 = data[..., 0]
241+
242+
if smask is None:
243+
smask = np.zeros_like(b0)
244+
smask[b0 > np.percentile(b0, 0.85)] = 1
245+
246+
if nmask is None:
247+
nmask = np.zeros_like(b0)
248+
try:
249+
bmask = settings['mask']
250+
nmask[~bmask] = 1
251+
except AttributeError:
252+
nmask[b0 < np.percentile(b0, 0.15)] = 1
253+
else:
254+
nmask = np.squeeze(nmask)
255+
nmask[nmask > 0] = 1
256+
257+
den = np.zeros_like(data)
258+
snr = []
259+
for i in range(data.shape[-1]):
260+
d = data[..., i]
261+
s = np.mean(d[smask > 0])
262+
n = np.std(d[nmask > 0])
263+
snr.append(s/n)
264+
den[..., i] = nlmeans(d, s/n, **settings)
265+
266+
den = np.squeeze(den)
242267
nb.Nifti1Image(den.astype(hdr.get_data_dtype()), aff,
243268
hdr).to_filename(out_file)
244-
return out_file, sigma
269+
return out_file, snr

nipype/interfaces/dipy/tests/test_auto_CSD.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from nipype.interfaces.dipy.reconstruction import CSD
44

55
def test_CSD_inputs():
6-
input_map = dict(ignore_exception=dict(nohash=True,
6+
input_map = dict(b0_thres=dict(usedefault=True,
7+
),
8+
ignore_exception=dict(nohash=True,
79
usedefault=True,
810
),
911
in_bval=dict(mandatory=True,

nipype/interfaces/dipy/tests/test_auto_Denoise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def test_Denoise_inputs():
1212
usedefault=True,
1313
),
1414
patch_radius=dict(),
15+
signal_mask=dict(),
1516
)
1617
inputs = Denoise.input_spec()
1718

nipype/interfaces/dipy/tests/test_auto_DipyBaseInterface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from nipype.interfaces.dipy.base import DipyBaseInterface
44

55
def test_DipyBaseInterface_inputs():
6-
input_map = dict(ignore_exception=dict(nohash=True,
6+
input_map = dict(b0_thres=dict(usedefault=True,
7+
),
8+
ignore_exception=dict(nohash=True,
79
usedefault=True,
810
),
911
in_bval=dict(mandatory=True,

nipype/interfaces/dipy/tests/test_auto_EstimateResponseSH.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from nipype.interfaces.dipy.reconstruction import EstimateResponseSH
44

55
def test_EstimateResponseSH_inputs():
6-
input_map = dict(fa_thresh=dict(usedefault=True,
6+
input_map = dict(b0_thres=dict(usedefault=True,
7+
),
8+
fa_thresh=dict(usedefault=True,
79
),
810
ignore_exception=dict(nohash=True,
911
usedefault=True,
@@ -19,8 +21,7 @@ def test_EstimateResponseSH_inputs():
1921
in_mask=dict(),
2022
out_prefix=dict(),
2123
response=dict(),
22-
save_glyph=dict(mandatory=True,
23-
usedefault=True,
24+
save_glyph=dict(usedefault=True,
2425
),
2526
)
2627
inputs = EstimateResponseSH.input_spec()

nipype/interfaces/dipy/tests/test_auto_RESTORE.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from nipype.interfaces.dipy.reconstruction import RESTORE
44

55
def test_RESTORE_inputs():
6-
input_map = dict(ignore_exception=dict(nohash=True,
6+
input_map = dict(b0_thres=dict(usedefault=True,
7+
),
8+
ignore_exception=dict(nohash=True,
79
usedefault=True,
810
),
911
in_bval=dict(mandatory=True,

0 commit comments

Comments
 (0)