21
21
from .. import config , logging
22
22
from ..utils .misc import package_check
23
23
24
- from ..interfaces .base import (BaseInterface , traits , TraitedSpec , File ,
25
- InputMultiPath , BaseInterfaceInputSpec ,
26
- isdefined )
27
- from .. utils import NUMPY_MMAP
24
+ from ..interfaces .base import (
25
+ SimpleInterface , BaseInterface , traits , TraitedSpec , File ,
26
+ InputMultiPath , BaseInterfaceInputSpec ,
27
+ isdefined )
28
28
29
29
iflogger = logging .getLogger ('interface' )
30
30
@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
383
383
File (exists = True ),
384
384
mandatory = True ,
385
385
desc = 'Test image. Requires the same dimensions as in_ref.' )
386
+ in_mask = File (exists = True , desc = 'calculate overlap only within mask' )
386
387
weighting = traits .Enum (
387
388
'none' ,
388
389
'volume' ,
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
403
404
class FuzzyOverlapOutputSpec (TraitedSpec ):
404
405
jaccard = traits .Float (desc = 'Fuzzy Jaccard Index (fJI), all the classes' )
405
406
dice = traits .Float (desc = 'Fuzzy Dice Index (fDI), all the classes' )
406
- diff_file = File (
407
- exists = True ,
408
- desc =
409
- 'resulting difference-map of all classes, using the chosen weighting' )
410
407
class_fji = traits .List (
411
408
traits .Float (),
412
409
desc = 'Array containing the fJIs of each computed class' )
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
415
412
desc = 'Array containing the fDIs of each computed class' )
416
413
417
414
418
- class FuzzyOverlap (BaseInterface ):
415
+ class FuzzyOverlap (SimpleInterface ):
419
416
"""Calculates various overlap measures between two maps, using the fuzzy
420
417
definition proposed in: Crum et al., Generalized Overlap Measures for
421
418
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
@@ -439,77 +436,77 @@ class FuzzyOverlap(BaseInterface):
439
436
output_spec = FuzzyOverlapOutputSpec
440
437
441
438
def _run_interface (self , runtime ):
442
- ncomp = len (self .inputs .in_ref )
443
- assert (ncomp == len (self .inputs .in_tst ))
444
- weights = np .ones (shape = ncomp )
445
-
446
- img_ref = np .array ([
447
- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
448
- for fname in self .inputs .in_ref
449
- ])
450
- img_tst = np .array ([
451
- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
452
- for fname in self .inputs .in_tst
453
- ])
454
-
455
- msk = np .sum (img_ref , axis = 0 )
456
- msk [msk > 0 ] = 1.0
457
- tst_msk = np .sum (img_tst , axis = 0 )
458
- tst_msk [tst_msk > 0 ] = 1.0
459
-
460
- # check that volumes are normalized
461
- # img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
462
- # img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]
463
-
464
- self ._jaccards = []
465
- volumes = []
466
-
467
- diff_im = np .zeros (img_ref .shape )
468
-
469
- for ref_comp , tst_comp , diff_comp in zip (img_ref , img_tst , diff_im ):
470
- num = np .minimum (ref_comp , tst_comp )
471
- ddr = np .maximum (ref_comp , tst_comp )
472
- diff_comp [ddr > 0 ] += 1.0 - (num [ddr > 0 ] / ddr [ddr > 0 ])
473
- self ._jaccards .append (np .sum (num ) / np .sum (ddr ))
474
- volumes .append (np .sum (ref_comp ))
475
-
476
- self ._dices = 2.0 * (np .array (self ._jaccards ) /
477
- (np .array (self ._jaccards ) + 1.0 ))
439
+ # Load data
440
+ refdata = nb .concat_images (self .inputs .in_ref ).get_data ()
441
+ tstdata = nb .concat_images (self .inputs .in_tst ).get_data ()
442
+
443
+ # Data must have same shape
444
+ if not refdata .shape == tstdata .shape :
445
+ raise RuntimeError (
446
+ 'Size of "in_tst" %s must match that of "in_ref" %s.' %
447
+ (tstdata .shape , refdata .shape ))
478
448
449
+ # Load mask
450
+ mask = np .ones_like (refdata [..., 0 ], dtype = bool )
451
+ if isdefined (self .inputs .in_mask ):
452
+ mask = nb .load (self .inputs .in_mask ).get_data ()
453
+ mask = mask > 0
454
+ assert mask .shape == refdata .shape [:- 1 ]
455
+
456
+ ncomp = refdata .shape [- 1 ]
457
+
458
+ # Drop data outside mask
459
+ refdata = refdata [mask [..., np .newaxis ]]
460
+ tstdata = tstdata [mask [..., np .newaxis ]]
461
+
462
+ if np .any (refdata < 0.0 ):
463
+ iflogger .warning ('Negative values encountered in "in_ref" input, '
464
+ 'taking absolute values.' )
465
+ refdata = np .abs (refdata )
466
+
467
+ if np .any (tstdata < 0.0 ):
468
+ iflogger .warning ('Negative values encountered in "in_tst" input, '
469
+ 'taking absolute values.' )
470
+ tstdata = np .abs (tstdata )
471
+
472
+ if np .any (refdata > 1.0 ):
473
+ iflogger .warning ('Values greater than 1.0 found in "in_ref" input, '
474
+ 'scaling values.' )
475
+ refdata /= refdata .max ()
476
+
477
+ if np .any (tstdata > 1.0 ):
478
+ iflogger .warning ('Values greater than 1.0 found in "in_tst" input, '
479
+ 'scaling values.' )
480
+ tstdata /= tstdata .max ()
481
+
482
+ numerators = np .atleast_2d (
483
+ np .minimum (refdata , tstdata ).reshape ((- 1 , ncomp )))
484
+ denominators = np .atleast_2d (
485
+ np .maximum (refdata , tstdata ).reshape ((- 1 , ncomp )))
486
+
487
+ jaccards = numerators .sum (axis = 0 ) / denominators .sum (axis = 0 )
488
+
489
+ # Calculate weights
490
+ weights = np .ones_like (jaccards , dtype = float )
479
491
if self .inputs .weighting != "none" :
492
+ volumes = np .sum ((refdata + tstdata ) > 0 , axis = 1 ).reshape ((- 1 , ncomp ))
480
493
weights = 1.0 / np .array (volumes )
481
494
if self .inputs .weighting == "squared_vol" :
482
495
weights = weights ** 2
483
496
484
497
weights = weights / np .sum (weights )
498
+ dices = 2.0 * jaccards / (jaccards + 1.0 )
485
499
486
- setattr (self , '_jaccard' , np .sum (weights * self ._jaccards ))
487
- setattr (self , '_dice' , np .sum (weights * self ._dices ))
488
-
489
- diff = np .zeros (diff_im [0 ].shape )
490
-
491
- for w , ch in zip (weights , diff_im ):
492
- ch [msk == 0 ] = 0
493
- diff += w * ch
494
-
495
- nb .save (
496
- nb .Nifti1Image (diff ,
497
- nb .load (self .inputs .in_ref [0 ]).affine ,
498
- nb .load (self .inputs .in_ref [0 ]).header ),
499
- self .inputs .out_file )
500
+ # Fill-in the results object
501
+ self ._results ['jaccard' ] = float (np .sum (weights * jaccards ))
502
+ self ._results ['dice' ] = float (np .sum (weights * dices ))
503
+ self ._results ['class_fji' ] = [
504
+ float (v ) for v in jaccards .astype (float ).tolist ()]
505
+ self ._results ['class_fdi' ] = [
506
+ float (v ) for v in dices .astype (float ).tolist ()]
500
507
501
508
return runtime
502
509
503
- def _list_outputs (self ):
504
- outputs = self ._outputs ().get ()
505
- for method in ("dice" , "jaccard" ):
506
- outputs [method ] = getattr (self , '_' + method )
507
- # outputs['volume_difference'] = self._volume
508
- outputs ['diff_file' ] = os .path .abspath (self .inputs .out_file )
509
- outputs ['class_fji' ] = np .array (self ._jaccards ).astype (float ).tolist ()
510
- outputs ['class_fdi' ] = self ._dices .astype (float ).tolist ()
511
- return outputs
512
-
513
510
514
511
class ErrorMapInputSpec (BaseInterfaceInputSpec ):
515
512
in_ref = File (
0 commit comments