Skip to content

Commit 664b93c

Browse files
committed
Improvements on ErrorMap implementation
Now it uses numpy.linalg.norm
1 parent 960a374 commit 664b93c

File tree

3 files changed

+87
-50
lines changed

3 files changed

+87
-50
lines changed

nipype/algorithms/metrics.py

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def _run_interface(self, runtime):
384384
diff+= w* ch
385385

386386
nb.save(nb.Nifti1Image(diff, nb.load( self.inputs.in_ref[0]).get_affine(),
387-
nb.load( self.inputs.in_ref[0]).get_header()), self.inputs.out_file )
387+
nb.load(self.inputs.in_ref[0]).get_header()), self.inputs.out_file)
388388

389389

390390
return runtime
@@ -400,20 +400,20 @@ def _list_outputs(self):
400400
return outputs
401401

402402

403-
class ErrorMapInputSpec( BaseInterfaceInputSpec ):
403+
class ErrorMapInputSpec(BaseInterfaceInputSpec):
404404
in_ref = File(exists=True, mandatory=True,
405405
desc="Reference image. Requires the same dimensions as in_tst.")
406406
in_tst = File(exists=True, mandatory=True,
407407
desc="Test image. Requires the same dimensions as in_ref.")
408408
mask = File(exists=True, desc="calculate overlap only within this mask.")
409-
method = traits.Enum( "squared_diff", "eucl",
410-
desc='',
411-
usedefault=True )
412-
out_map = File( desc="Name for the output file" )
409+
metric = traits.Enum("sqeuclidean", "euclidean",
410+
desc='error map metric (as implemented in scipy cdist)',
411+
usedefault=True, mandatory=True)
412+
out_map = File(desc="Name for the output file")
413413

414-
class ErrorMapOutputSpec(TraitedSpec):
415-
out_map = File(exists=True, desc="resulting error map" )
416414

415+
class ErrorMapOutputSpec(TraitedSpec):
416+
out_map = File(exists=True, desc="resulting error map")
417417

418418

419419
class ErrorMap(BaseInterface):
@@ -429,75 +429,79 @@ class ErrorMap(BaseInterface):
429429
"""
430430
input_spec = ErrorMapInputSpec
431431
output_spec = ErrorMapOutputSpec
432-
_out_file = ""
433-
434-
435-
def _run_interface( self, runtime ):
436-
nii_ref = nb.load( self.inputs.in_ref )
437-
ref_data = np.squeeze( nii_ref.get_data() )
438-
tst_data = np.squeeze( nb.load( self.inputs.in_tst ).get_data() )
432+
_out_file = ''
439433

440-
assert( ref_data.ndim == tst_data.ndim )
434+
def _run_interface(self, runtime):
435+
from scipy.spatial.distance import cdist, pdist
436+
nii_ref = nb.load(self.inputs.in_ref)
437+
ref_data = np.squeeze(nii_ref.get_data())
438+
tst_data = np.squeeze(nb.load(self.inputs.in_tst).get_data())
439+
assert(ref_data.ndim == tst_data.ndim)
441440

441+
comps = 1
442+
mapshape = ref_data.shape
442443

443-
if ( ref_data.ndim == 4 ):
444+
if (ref_data.ndim == 4):
444445
comps = ref_data.shape[-1]
445446
mapshape = ref_data.shape[:-1]
446-
refvector = np.reshape( ref_data, (-1,comps))
447-
tstvector = np.reshape( tst_data, (-1,comps))
448-
else:
449-
mapshape = ref_data.shape
450-
refvector = ref_data.reshape(-1)
451-
tstvector = tst_data.reshape(-1)
452447

453-
if isdefined( self.inputs.mask ):
448+
if isdefined(self.inputs.mask):
454449
msk = nb.load( self.inputs.mask ).get_data()
450+
if (mapshape != msk.shape):
451+
raise RuntimeError("Mask should match volume shape, \
452+
mask is %s and volumes are %s" %
453+
(list(msk.shape), list(mapshape)))
454+
else:
455+
msk = np.ones(shape=mapshape)
456+
457+
mskvector = msk.reshape(-1)
458+
msk_idxs = np.where(mskvector==1)
459+
refvector = ref_data.reshape(-1,comps)[msk_idxs].astype(np.float32)
460+
tstvector = tst_data.reshape(-1,comps)[msk_idxs].astype(np.float32)
461+
diffvector = (refvector-tstvector)
462+
463+
if self.inputs.metric == 'sqeuclidean':
464+
errvector = diffvector**2
465+
elif self.inputs.metric == 'euclidean':
466+
X = np.hstack((refvector, tstvector))
467+
errvector = np.linalg.norm(X, axis=1)
468+
469+
if (comps > 1):
470+
errvector = np.sum(errvector, axis=1)
471+
else:
472+
errvector = np.squeeze(errvector)
455473

456-
if ( mapshape != msk.shape ):
457-
raise RuntimeError( "Mask should match volume shape, \
458-
mask is %s and volumes are %s" %
459-
( list(msk.shape), list(mapshape) ) )
460-
461-
mskvector = msk.reshape(-1)
462-
refvector = refvector * mskvector[:,np.newaxis]
463-
tstvector = tstvector * mskvector[:,np.newaxis]
464-
465-
diffvector = (tstvector-refvector)**2
466-
if ( ref_data.ndim > 1 ):
467-
diffvector = np.sum( diffvector, axis=1 )
474+
errvectorexp = np.zeros_like(mskvector)
475+
errvectorexp[msk_idxs] = errvector
468476

469-
diffmap = diffvector.reshape( mapshape )
477+
errmap = errvectorexp.reshape(mapshape)
470478

471479
hdr = nii_ref.get_header().copy()
472-
hdr.set_data_dtype( np.float32 )
480+
hdr.set_data_dtype(np.float32)
473481
hdr['data_type'] = 16
474-
hdr.set_data_shape( diffmap.shape )
482+
hdr.set_data_shape(mapshape)
475483

476-
niimap = nb.Nifti1Image( diffmap.astype( np.float32 ),
477-
nii_ref.get_affine(), hdr )
478-
479-
if not isdefined( self.inputs.out_map ):
480-
fname,ext = op.splitext( op.basename( self.inputs.in_tst ) )
484+
if not isdefined(self.inputs.out_map):
485+
fname,ext = op.splitext(op.basename(self.inputs.in_tst))
481486
if ext=='.gz':
482-
fname,ext2 = op.splitext( fname )
487+
fname,ext2 = op.splitext(fname)
483488
ext = ext2 + ext
484-
self._out_file = op.abspath( fname + "_errmap" + ext )
489+
self._out_file = op.abspath(fname + "_errmap" + ext)
485490
else:
486491
self._out_file = self.inputs.out_map
487492

488-
nb.save( niimap, self._out_file )
493+
nb.Nifti1Image(errmap.astype(np.float32), nii_ref.get_affine(),
494+
hdr).to_filename(self._out_file)
489495

490496
return runtime
491497

492498
def _list_outputs(self):
493499
outputs = self.output_spec().get()
494500
outputs['out_map'] = self._out_file
495-
496501
return outputs
497502

498503

499504
class SimilarityInputSpec(BaseInterfaceInputSpec):
500-
501505
volume1 = File(exists=True, desc="3D/4D volume", mandatory=True)
502506
volume2 = File(exists=True, desc="3D/4D volume", mandatory=True)
503507
mask1 = File(exists=True, desc="3D volume")

nipype/algorithms/tests/test_auto_ErrorMap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ def test_ErrorMap_inputs():
1111
in_tst=dict(mandatory=True,
1212
),
1313
mask=dict(),
14-
method=dict(usedefault=True,
14+
metric=dict(mandatory=True,
15+
usedefault=True,
1516
),
1617
out_map=dict(),
1718
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
2+
from nipype.testing import assert_equal
3+
from nipype.algorithms.metrics import Similarity
4+
5+
def test_Similarity_inputs():
6+
input_map = dict(ignore_exception=dict(nohash=True,
7+
usedefault=True,
8+
),
9+
mask1=dict(),
10+
mask2=dict(),
11+
metric=dict(usedefault=True,
12+
),
13+
volume1=dict(mandatory=True,
14+
),
15+
volume2=dict(mandatory=True,
16+
),
17+
)
18+
inputs = Similarity.input_spec()
19+
20+
for key, metadata in input_map.items():
21+
for metakey, value in metadata.items():
22+
yield assert_equal, getattr(inputs.traits()[key], metakey), value
23+
24+
def test_Similarity_outputs():
25+
output_map = dict(similarity=dict(),
26+
)
27+
outputs = Similarity.output_spec()
28+
29+
for key, metadata in output_map.items():
30+
for metakey, value in metadata.items():
31+
yield assert_equal, getattr(outputs.traits()[key], metakey), value
32+

0 commit comments

Comments
 (0)