Skip to content

Commit 403451e

Browse files
committed
fix:computation of noise in AddNoise
1 parent 7f0d029 commit 403451e

File tree

1 file changed

+45
-44
lines changed

1 file changed

+45
-44
lines changed

nipype/algorithms/misc.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
Change directory to provide relative paths for doctests
77
>>> import os
8-
>>> filepath = os.path.dirname( os.path.realpath( __file__ ) )
8+
>>> filepath = os.path.dirname(os.path.realpath(__file__))
99
>>> datadir = os.path.realpath(os.path.join(filepath, '../testing/data'))
1010
>>> os.chdir(datadir)
1111
@@ -34,7 +34,7 @@
3434
from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
3535
InputMultiPath, OutputMultiPath,
3636
BaseInterfaceInputSpec, isdefined,
37-
DynamicTraitedSpec )
37+
DynamicTraitedSpec)
3838
from nipype.utils.filemanip import fname_presuffix, split_filename
3939
iflogger = logging.getLogger('interface')
4040

@@ -785,7 +785,7 @@ def _list_outputs(self):
785785

786786
class AddCSVRowInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
787787
in_file = traits.File(mandatory=True, desc='Input comma-separated value (CSV) files')
788-
_outputs = traits.Dict( traits.Any, value={}, usedefault=True )
788+
_outputs = traits.Dict(traits.Any, value={}, usedefault=True)
789789

790790
def __setattr__(self, key, value):
791791
if key not in self.copyable_trait_names():
@@ -876,7 +876,7 @@ def _run_interface(self, runtime):
876876

877877
if op.exists(self.inputs.in_file):
878878
formerdf = pd.read_csv(self.inputs.in_file, index_col=0)
879-
df = pd.concat([formerdf, df], ignore_index=True )
879+
df = pd.concat([formerdf, df], ignore_index=True)
880880

881881
with open(self.inputs.in_file, 'w') as f:
882882
df.to_csv(f)
@@ -993,25 +993,25 @@ class AddNoise(BaseInterface):
993993
output_spec = AddNoiseOutputSpec
994994

995995
def _run_interface(self, runtime):
996-
in_image = nb.load( self.inputs.in_file )
996+
in_image = nb.load(self.inputs.in_file)
997997
in_data = in_image.get_data()
998998
snr = self.inputs.snr
999999

1000-
if isdefined( self.inputs.in_mask ):
1001-
in_mask = nb.load( self.inputs.in_mask ).get_data()
1000+
if isdefined(self.inputs.in_mask):
1001+
in_mask = nb.load(self.inputs.in_mask).get_data()
10021002
else:
1003-
in_mask = np.ones_like( in_data )
1003+
in_mask = np.ones_like(in_data)
10041004

10051005
result = self.gen_noise(in_data, mask=in_mask, snr_db=snr,
10061006
dist=self.inputs.dist, bg_dist=self.inputs.bg_dist)
10071007
res_im = nb.Nifti1Image(result, in_image.get_affine(), in_image.get_header())
10081008
res_im.to_filename(self._gen_output_filename())
10091009
return runtime
10101010

1011-
def _gen_output_filename( self ):
1012-
if not isdefined( self.inputs.out_file ):
1013-
_, base, _ = split_filename( self.inputs.in_file )
1014-
out_file = os.path.abspath( base + ('_SNR%03.2f' % self.inputs.snr) + '.nii.gz' )
1011+
def _gen_output_filename(self):
1012+
if not isdefined(self.inputs.out_file):
1013+
_, base, ext = split_filename(self.inputs.in_file)
1014+
out_file = os.path.abspath('%s_SNR%03.2f%s' % (base, self.inputs.snr, ext))
10151015
else:
10161016
out_file = self.inputs.out_file
10171017

@@ -1030,30 +1030,31 @@ def gen_noise(self, image, mask=None, snr_db=10.0, dist='normal', bg_dist='norma
10301030
from math import sqrt
10311031
snr = sqrt(np.power(10.0, snr_db/10.0))
10321032

1033-
if dist == 'normal':
1034-
noise = np.random.normal(size=image.shape)
1035-
else:
1036-
raise NotImplementedError('Only normal distribution is supported')
1037-
10381033
if mask is None:
10391034
mask = np.ones_like(image)
10401035

10411036
signal = image[mask>0].reshape(-1)
10421037
signal = signal - signal.mean()
1043-
S = (signal.var())**2
1038+
sigma_s = signal.var()
1039+
sigma_n = sqrt((sigma_s**2)/snr)
1040+
1041+
if dist == 'normal':
1042+
noise = np.random.normal(size=image.shape, scale=sigma_n)
1043+
else:
1044+
raise NotImplementedError('Only normal distribution is supported')
10441045

10451046
if np.any(mask==0):
10461047
if bg_dist == 'rayleigh':
1047-
bg_noise = np.random.rayleigh(size=image.shape)
1048+
bg_noise = np.random.rayleigh(size=image.shape, scale=sigma_n)
10481049
noise[mask==0] = bg_noise[mask==0]
10491050

1050-
im_noise = image + noise * (S/snr)
1051+
im_noise = image + noise
10511052
return im_noise
10521053

10531054

10541055
class NormalizeProbabilityMapSetInputSpec(TraitedSpec):
10551056
in_files = InputMultiPath(File(exists=True, mandatory=True,
1056-
desc='The tpms to be normalized') )
1057+
desc='The tpms to be normalized'))
10571058
in_mask = File(exists=True, mandatory=False,
10581059
desc='Masked voxels must sum up 1.0, 0.0 otherwise.')
10591060

@@ -1080,10 +1081,10 @@ class NormalizeProbabilityMapSet(BaseInterface):
10801081
def _run_interface(self, runtime):
10811082
mask = None
10821083

1083-
if isdefined( self.inputs.in_mask ):
1084+
if isdefined(self.inputs.in_mask):
10841085
mask = self.inputs.in_mask
10851086

1086-
self._out_filenames = normalize_tpms( self.inputs.in_files, mask )
1087+
self._out_filenames = normalize_tpms(self.inputs.in_files, mask)
10871088
return runtime
10881089

10891090
def _list_outputs(self):
@@ -1092,7 +1093,7 @@ def _list_outputs(self):
10921093
return outputs
10931094

10941095

1095-
def normalize_tpms( in_files, in_mask=None, out_files=[] ):
1096+
def normalize_tpms(in_files, in_mask=None, out_files=[]):
10961097
"""
10971098
Returns the input tissue probability maps (tpms, aka volume fractions)
10981099
normalized to sum up 1.0 at each voxel within the mask.
@@ -1101,16 +1102,16 @@ def normalize_tpms( in_files, in_mask=None, out_files=[] ):
11011102
import numpy as np
11021103
import os.path as op
11031104

1104-
in_files = np.atleast_1d( in_files ).tolist()
1105+
in_files = np.atleast_1d(in_files).tolist()
11051106

1106-
if len(out_files)!=len(in_files):
1107-
for i,finname in enumerate( in_files ):
1108-
fname,fext = op.splitext( op.basename( finname ) )
1107+
if len(out_files) != len(in_files):
1108+
for i,finname in enumerate(in_files):
1109+
fname,fext = op.splitext(op.basename(finname))
11091110
if fext == '.gz':
1110-
fname,fext2 = op.splitext( fname )
1111+
fname,fext2 = op.splitext(fname)
11111112
fext = fext2 + fext
11121113

1113-
out_file = op.abspath(fname+'_norm'+('_%02d' % i)+fext)
1114+
out_file = op.abspath('%s_norm_%02d%s' % (fname,i,fext))
11141115
out_files+= [out_file]
11151116

11161117
imgs = [nib.load(fim) for fim in in_files]
@@ -1120,39 +1121,39 @@ def normalize_tpms( in_files, in_mask=None, out_files=[] ):
11201121
img_data[img_data>0.0] = 1.0
11211122
hdr = imgs[0].get_header().copy()
11221123
hdr['data_type']= 16
1123-
hdr.set_data_dtype( 'float32' )
1124-
nib.save( nib.Nifti1Image( img_data.astype(np.float32), imgs[0].get_affine(), hdr ), out_files[0] )
1124+
hdr.set_data_dtype(np.float32)
1125+
nib.save(nib.Nifti1Image(img_data.astype(np.float32), imgs[0].get_affine(), hdr), out_files[0])
11251126
return out_files[0]
11261127

1127-
img_data = np.array( [ im.get_data() for im in imgs ] ).astype( 'f32' )
1128+
img_data = np.array([im.get_data() for im in imgs]).astype(np.float32)
11281129
#img_data[img_data>1.0] = 1.0
11291130
img_data[img_data<0.0] = 0.0
1130-
weights = np.sum( img_data, axis=0 )
1131+
weights = np.sum(img_data, axis=0)
11311132

1132-
msk = np.ones_like( imgs[0].get_data() )
1133+
msk = np.ones_like(imgs[0].get_data())
11331134
msk[ weights<= 0 ] = 0
11341135

11351136
if not in_mask is None:
1136-
msk = nib.load( in_mask ).get_data()
1137+
msk = nib.load(in_mask).get_data()
11371138
msk[ msk<=0 ] = 0
11381139
msk[ msk>0 ] = 1
11391140

1140-
msk = np.ma.masked_equal( msk, 0 )
1141+
msk = np.ma.masked_equal(msk, 0)
11411142

11421143

1143-
for i,out_file in enumerate( out_files ):
1144-
data = np.ma.masked_equal( img_data[i], 0 )
1144+
for i,out_file in enumerate(out_files):
1145+
data = np.ma.masked_equal(img_data[i], 0)
11451146
probmap = data / weights
11461147
hdr = imgs[i].get_header().copy()
11471148
hdr['data_type']= 16
1148-
hdr.set_data_dtype( 'float32' )
1149-
nib.save( nib.Nifti1Image( probmap.astype(np.float32), imgs[i].get_affine(), hdr ), out_file )
1149+
hdr.set_data_dtype('float32')
1150+
nib.save(nib.Nifti1Image(probmap.astype(np.float32), imgs[i].get_affine(), hdr), out_file)
11501151

11511152
return out_files
11521153

11531154

11541155
# Deprecated interfaces ---------------------------------------------------------
1155-
class Distance( nam.Distance ):
1156+
class Distance(nam.Distance):
11561157
"""Calculates distance between two volumes.
11571158
11581159
.. deprecated:: 0.10.0
@@ -1164,7 +1165,7 @@ def __init__(self, **inputs):
11641165
" please use nipype.algorithms.metrics.Distance"),
11651166
DeprecationWarning)
11661167

1167-
class Overlap( nam.Overlap ):
1168+
class Overlap(nam.Overlap):
11681169
"""Calculates various overlap measures between two maps.
11691170
11701171
.. deprecated:: 0.10.0
@@ -1177,7 +1178,7 @@ def __init__(self, **inputs):
11771178
DeprecationWarning)
11781179

11791180

1180-
class FuzzyOverlap( nam.FuzzyOverlap ):
1181+
class FuzzyOverlap(nam.FuzzyOverlap):
11811182
"""Calculates various overlap measures between two maps, using a fuzzy
11821183
definition.
11831184

0 commit comments

Comments
 (0)