Skip to content

Commit 76b0e34

Browse files
committed
[ENH] Revise the implementation of FuzzyOverlap
- [x] Accept calculating overlap within an ``in_mask`` - [x] Drop calculation of the difference image
1 parent dc09e00 commit 76b0e34

File tree

1 file changed

+66
-69
lines changed

1 file changed

+66
-69
lines changed

nipype/algorithms/metrics.py

Lines changed: 66 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
from .. import config, logging
2222
from ..utils.misc import package_check
2323

24-
from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
25-
InputMultiPath, BaseInterfaceInputSpec,
26-
isdefined)
27-
from ..utils import NUMPY_MMAP
24+
from ..interfaces.base import (
25+
SimpleInterface, BaseInterface, traits, TraitedSpec, File,
26+
InputMultiPath, BaseInterfaceInputSpec,
27+
isdefined)
2828

2929
iflogger = logging.getLogger('interface')
3030

@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
383383
File(exists=True),
384384
mandatory=True,
385385
desc='Test image. Requires the same dimensions as in_ref.')
386+
in_mask = File(exists=True, desc='calculate overlap only within mask')
386387
weighting = traits.Enum(
387388
'none',
388389
'volume',
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
403404
class FuzzyOverlapOutputSpec(TraitedSpec):
404405
jaccard = traits.Float(desc='Fuzzy Jaccard Index (fJI), all the classes')
405406
dice = traits.Float(desc='Fuzzy Dice Index (fDI), all the classes')
406-
diff_file = File(
407-
exists=True,
408-
desc=
409-
'resulting difference-map of all classes, using the chosen weighting')
410407
class_fji = traits.List(
411408
traits.Float(),
412409
desc='Array containing the fJIs of each computed class')
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
415412
desc='Array containing the fDIs of each computed class')
416413

417414

418-
class FuzzyOverlap(BaseInterface):
415+
class FuzzyOverlap(SimpleInterface):
419416
"""Calculates various overlap measures between two maps, using the fuzzy
420417
definition proposed in: Crum et al., Generalized Overlap Measures for
421418
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
@@ -439,77 +436,77 @@ class FuzzyOverlap(BaseInterface):
439436
output_spec = FuzzyOverlapOutputSpec
440437

441438
def _run_interface(self, runtime):
442-
ncomp = len(self.inputs.in_ref)
443-
assert (ncomp == len(self.inputs.in_tst))
444-
weights = np.ones(shape=ncomp)
445-
446-
img_ref = np.array([
447-
nb.load(fname, mmap=NUMPY_MMAP).get_data()
448-
for fname in self.inputs.in_ref
449-
])
450-
img_tst = np.array([
451-
nb.load(fname, mmap=NUMPY_MMAP).get_data()
452-
for fname in self.inputs.in_tst
453-
])
454-
455-
msk = np.sum(img_ref, axis=0)
456-
msk[msk > 0] = 1.0
457-
tst_msk = np.sum(img_tst, axis=0)
458-
tst_msk[tst_msk > 0] = 1.0
459-
460-
# check that volumes are normalized
461-
# img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
462-
# img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]
463-
464-
self._jaccards = []
465-
volumes = []
466-
467-
diff_im = np.zeros(img_ref.shape)
468-
469-
for ref_comp, tst_comp, diff_comp in zip(img_ref, img_tst, diff_im):
470-
num = np.minimum(ref_comp, tst_comp)
471-
ddr = np.maximum(ref_comp, tst_comp)
472-
diff_comp[ddr > 0] += 1.0 - (num[ddr > 0] / ddr[ddr > 0])
473-
self._jaccards.append(np.sum(num) / np.sum(ddr))
474-
volumes.append(np.sum(ref_comp))
475-
476-
self._dices = 2.0 * (np.array(self._jaccards) /
477-
(np.array(self._jaccards) + 1.0))
439+
# Load data
440+
refdata = nb.concat_images(self.inputs.in_ref).get_data()
441+
tstdata = nb.concat_images(self.inputs.in_tst).get_data()
442+
443+
# Data must have same shape
444+
if not refdata.shape == tstdata.shape:
445+
raise RuntimeError(
446+
'Size of "in_tst" %s must match that of "in_ref" %s.' %
447+
(tstdata.shape, refdata.shape))
478448

449+
# Load mask
450+
mask = np.ones_like(refdata[..., 0], dtype=bool)
451+
if isdefined(self.inputs.in_mask):
452+
mask = nb.load(self.inputs.in_mask).get_data()
453+
mask = mask > 0
454+
assert mask.shape == refdata.shape[:-1]
455+
456+
ncomp = refdata.shape[-1]
457+
458+
# Drop data outside mask
459+
refdata = refdata[mask[..., np.newaxis]]
460+
tstdata = tstdata[mask[..., np.newaxis]]
461+
462+
if np.any(refdata < 0.0):
463+
iflogger.warning('Negative values encountered in "in_ref" input, '
464+
'taking absolute values.')
465+
refdata = np.abs(refdata)
466+
467+
if np.any(tstdata < 0.0):
468+
iflogger.warning('Negative values encountered in "in_tst" input, '
469+
'taking absolute values.')
470+
tstdata = np.abs(tstdata)
471+
472+
if np.any(refdata > 1.0):
473+
iflogger.warning('Values greater than 1.0 found in "in_ref" input, '
474+
'scaling values.')
475+
refdata /= refdata.max()
476+
477+
if np.any(tstdata > 1.0):
478+
iflogger.warning('Values greater than 1.0 found in "in_tst" input, '
479+
'scaling values.')
480+
tstdata /= tstdata.max()
481+
482+
numerators = np.atleast_2d(
483+
np.minimum(refdata, tstdata).reshape((-1, ncomp)))
484+
denominators = np.atleast_2d(
485+
np.maximum(refdata, tstdata).reshape((-1, ncomp)))
486+
487+
jaccards = numerators.sum(axis=0) / denominators.sum(axis=0)
488+
489+
# Calculate weights
490+
weights = np.ones_like(jaccards, dtype=float)
479491
if self.inputs.weighting != "none":
492+
volumes = np.sum((refdata + tstdata) > 0, axis=1).reshape((-1, ncomp))
480493
weights = 1.0 / np.array(volumes)
481494
if self.inputs.weighting == "squared_vol":
482495
weights = weights**2
483496

484497
weights = weights / np.sum(weights)
498+
dices = 2.0 * jaccards / (jaccards + 1.0)
485499

486-
setattr(self, '_jaccard', np.sum(weights * self._jaccards))
487-
setattr(self, '_dice', np.sum(weights * self._dices))
488-
489-
diff = np.zeros(diff_im[0].shape)
490-
491-
for w, ch in zip(weights, diff_im):
492-
ch[msk == 0] = 0
493-
diff += w * ch
494-
495-
nb.save(
496-
nb.Nifti1Image(diff,
497-
nb.load(self.inputs.in_ref[0]).affine,
498-
nb.load(self.inputs.in_ref[0]).header),
499-
self.inputs.out_file)
500+
# Fill-in the results object
501+
self._results['jaccard'] = float(np.sum(weights * jaccards))
502+
self._results['dice'] = float(np.sum(weights * dices))
503+
self._results['class_fji'] = [
504+
float(v) for v in jaccards.astype(float).tolist()]
505+
self._results['class_fdi'] = [
506+
float(v) for v in dices.astype(float).tolist()]
500507

501508
return runtime
502509

503-
def _list_outputs(self):
504-
outputs = self._outputs().get()
505-
for method in ("dice", "jaccard"):
506-
outputs[method] = getattr(self, '_' + method)
507-
# outputs['volume_difference'] = self._volume
508-
outputs['diff_file'] = os.path.abspath(self.inputs.out_file)
509-
outputs['class_fji'] = np.array(self._jaccards).astype(float).tolist()
510-
outputs['class_fdi'] = self._dices.astype(float).tolist()
511-
return outputs
512-
513510

514511
class ErrorMapInputSpec(BaseInterfaceInputSpec):
515512
in_ref = File(

0 commit comments

Comments
 (0)