@@ -437,6 +437,95 @@ def _list_outputs(self):
437
437
return outputs
438
438
439
439
440
+ class FuzzyOverlapInputSpec (BaseInterfaceInputSpec ):
441
+ in_ref = InputMultiPath ( File (exists = True ), mandatory = True ,
442
+ desc = "Reference image. Requires the same dimensions as in_tst." )
443
+ in_tst = InputMultiPath ( File (exists = True ), mandatory = True ,
444
+ desc = "Test image. Requires the same dimensions as in_ref." )
445
+ mask_volume = File ( exists = True , desc = "calculate overlap only within this mask." )
446
+ weighting = traits .Enum ("none" , "volume" , desc = '""none": no class-overlap weighting is performed\
447
+ "volume": computed class-overlaps are weighted by class volume' ,usedefault = True )
448
+ out_file = File ("diff.nii" , usedefault = True )
449
+
450
+
451
+ class FuzzyOverlapOutputSpec (TraitedSpec ):
452
+ jaccard = traits .Float ()
453
+ dice = traits .Float ()
454
+ diff_file = File (exists = True )
455
+
456
+
457
+ class FuzzyOverlap (BaseInterface ):
458
+ """
459
+ Calculates various overlap measures between two maps, using the fuzzy definition
460
+ proposed in: Crum et al., Generalized Overlap Measures for Evaluation and Validation
461
+ in Medical Image Analysis, IEEE Trans. Med. Ima. 25(11),pp 1451-1461, Nov. 2006.
462
+
463
+ reference.nii and test.nii are lists of 2/3D images, each element on the list containing
464
+ one volume fraction map of a class in a fuzzy partition of the domain.
465
+
466
+ Example
467
+ -------
468
+
469
+ >>> overlap = FuzzyOverlap()
470
+ >>> overlap.inputs.in_ref = [ 'ref_class0.nii', 'ref_class1.nii', 'ref_class2.nii' ]
471
+ >>> overlap.inputs.in_tst = [ 'tst_class0.nii', 'tst_class1.nii', 'tst_class2.nii' ]
472
+ >>> overlap.inputs.weighting = 'volume'
473
+ >>> res = overlap.run() # doctest: +SKIP
474
+ """
475
+
476
+ input_spec = FuzzyOverlapInputSpec
477
+ output_spec = FuzzyOverlapOutputSpec
478
+
479
+ def _bool_vec_dissimilarity (self , booldata1 , booldata2 , method ):
480
+ methods = {"dice" : dice , "jaccard" : jaccard }
481
+ if not (np .any (booldata1 ) or np .any (booldata2 )):
482
+ return 0
483
+ return 1 - methods [method ](booldata1 .flat , booldata2 .flat )
484
+
485
+ def _run_interface (self , runtime ):
486
+ ncomp = len (self .inputs .in_ref )
487
+ assert ( ncomp == len (self .inputs .in_tst ) )
488
+ weights = np .ones ( shape = ncomp )
489
+
490
+ img_ref = np .array ( [ nib .load ( fname ).get_data () for fname in self .inputs .in_ref ] )
491
+ img_tst = np .array ( [ nib .load ( fname ).get_data () for fname in self .inputs .in_tst ] )
492
+
493
+ #check that volumes are normalized
494
+ img_ref = img_ref / np .sum ( img_ref , axis = 0 )
495
+ img_tst = img_tst / np .sum ( img_tst , axis = 0 )
496
+
497
+ num = float ( np .minimum ( img_ref , img_test ) )
498
+ ddr = float ( np .maximum ( img_ref , img_test ) )
499
+ both_data = num / ddr
500
+
501
+ jaccards = np .sum ( num , axis = 0 ) / np .sum ( ddr , axis = 0 )
502
+ dices = 2.0 * jaccards / (1.0 + jaccards )
503
+
504
+ if self .inputs .weighting != "none" :
505
+ weights = 1.0 / np .sum ( img_ref , axis = 0 )
506
+
507
+ if self .inputs .weighting == "squared_vol" :
508
+ weights = weights ** 2
509
+
510
+ setattr ( self , '_jaccard' , np .sum ( weights * jaccards ) / np .sum ( weights ) )
511
+ setattr ( self , '_dice' , np .sum ( weights * dices ) / np .sum ( weights ) )
512
+
513
+ # todo, this is a N+1 dimensional file, update header and affine.
514
+ nb .save (nb .Nifti1Image (both_data , nii1 .get_affine (),
515
+ nii1 .get_header ()), self .inputs .out_file )
516
+
517
+ return runtime
518
+
519
+ def _list_outputs (self ):
520
+ outputs = self ._outputs ().get ()
521
+ for method in ("dice" , "jaccard" ):
522
+ outputs [method ] = getattr (self , '_' + method )
523
+ #outputs['volume_difference'] = self._volume
524
+ outputs ['diff_file' ] = os .path .abspath (self .inputs .out_file )
525
+ return outputs
526
+
527
+
528
+
440
529
class CreateNiftiInputSpec (BaseInterfaceInputSpec ):
441
530
data_file = File (exists = True , mandatory = True , desc = "ANALYZE img file" )
442
531
header_file = File (
0 commit comments