@@ -384,7 +384,7 @@ def _run_interface(self, runtime):
384
384
diff += w * ch
385
385
386
386
nb .save (nb .Nifti1Image (diff , nb .load ( self .inputs .in_ref [0 ]).get_affine (),
387
- nb .load ( self .inputs .in_ref [0 ]).get_header ()), self .inputs .out_file )
387
+ nb .load (self .inputs .in_ref [0 ]).get_header ()), self .inputs .out_file )
388
388
389
389
390
390
return runtime
@@ -400,20 +400,20 @@ def _list_outputs(self):
400
400
return outputs
401
401
402
402
403
- class ErrorMapInputSpec ( BaseInterfaceInputSpec ):
403
+ class ErrorMapInputSpec (BaseInterfaceInputSpec ):
404
404
in_ref = File (exists = True , mandatory = True ,
405
405
desc = "Reference image. Requires the same dimensions as in_tst." )
406
406
in_tst = File (exists = True , mandatory = True ,
407
407
desc = "Test image. Requires the same dimensions as in_ref." )
408
408
mask = File (exists = True , desc = "calculate overlap only within this mask." )
409
- method = traits .Enum ( "squared_diff " , "eucl " ,
410
- desc = '' ,
411
- usedefault = True )
412
- out_map = File ( desc = "Name for the output file" )
409
+ metric = traits .Enum ("sqeuclidean " , "euclidean " ,
410
+ desc = 'error map metric (as implemented in scipy cdist) ' ,
411
+ usedefault = True , mandatory = True )
412
+ out_map = File (desc = "Name for the output file" )
413
413
414
- class ErrorMapOutputSpec (TraitedSpec ):
415
- out_map = File (exists = True , desc = "resulting error map" )
416
414
415
+ class ErrorMapOutputSpec (TraitedSpec ):
416
+ out_map = File (exists = True , desc = "resulting error map" )
417
417
418
418
419
419
class ErrorMap (BaseInterface ):
@@ -429,75 +429,79 @@ class ErrorMap(BaseInterface):
429
429
"""
430
430
input_spec = ErrorMapInputSpec
431
431
output_spec = ErrorMapOutputSpec
432
- _out_file = ""
433
-
434
-
435
- def _run_interface ( self , runtime ):
436
- nii_ref = nb .load ( self .inputs .in_ref )
437
- ref_data = np .squeeze ( nii_ref .get_data () )
438
- tst_data = np .squeeze ( nb .load ( self .inputs .in_tst ).get_data () )
432
+ _out_file = ''
439
433
440
- assert ( ref_data .ndim == tst_data .ndim )
434
+ def _run_interface (self , runtime ):
435
+ from scipy .spatial .distance import cdist , pdist
436
+ nii_ref = nb .load (self .inputs .in_ref )
437
+ ref_data = np .squeeze (nii_ref .get_data ())
438
+ tst_data = np .squeeze (nb .load (self .inputs .in_tst ).get_data ())
439
+ assert (ref_data .ndim == tst_data .ndim )
441
440
441
+ comps = 1
442
+ mapshape = ref_data .shape
442
443
443
- if ( ref_data .ndim == 4 ):
444
+ if (ref_data .ndim == 4 ):
444
445
comps = ref_data .shape [- 1 ]
445
446
mapshape = ref_data .shape [:- 1 ]
446
- refvector = np .reshape ( ref_data , (- 1 ,comps ))
447
- tstvector = np .reshape ( tst_data , (- 1 ,comps ))
448
- else :
449
- mapshape = ref_data .shape
450
- refvector = ref_data .reshape (- 1 )
451
- tstvector = tst_data .reshape (- 1 )
452
447
453
- if isdefined ( self .inputs .mask ):
448
+ if isdefined (self .inputs .mask ):
454
449
msk = nb .load ( self .inputs .mask ).get_data ()
450
+ if (mapshape != msk .shape ):
451
+ raise RuntimeError ("Mask should match volume shape, \
452
+ mask is %s and volumes are %s" %
453
+ (list (msk .shape ), list (mapshape )))
454
+ else :
455
+ msk = np .ones (shape = mapshape )
456
+
457
+ mskvector = msk .reshape (- 1 )
458
+ msk_idxs = np .where (mskvector == 1 )
459
+ refvector = ref_data .reshape (- 1 ,comps )[msk_idxs ].astype (np .float32 )
460
+ tstvector = tst_data .reshape (- 1 ,comps )[msk_idxs ].astype (np .float32 )
461
+ diffvector = (refvector - tstvector )
462
+
463
+ if self .inputs .metric == 'sqeuclidean' :
464
+ errvector = diffvector ** 2
465
+ elif self .inputs .metric == 'euclidean' :
466
+ X = np .hstack ((refvector , tstvector ))
467
+ errvector = np .linalg .norm (X , axis = 1 )
468
+
469
+ if (comps > 1 ):
470
+ errvector = np .sum (errvector , axis = 1 )
471
+ else :
472
+ errvector = np .squeeze (errvector )
455
473
456
- if ( mapshape != msk .shape ):
457
- raise RuntimeError ( "Mask should match volume shape, \
458
- mask is %s and volumes are %s" %
459
- ( list (msk .shape ), list (mapshape ) ) )
460
-
461
- mskvector = msk .reshape (- 1 )
462
- refvector = refvector * mskvector [:,np .newaxis ]
463
- tstvector = tstvector * mskvector [:,np .newaxis ]
464
-
465
- diffvector = (tstvector - refvector )** 2
466
- if ( ref_data .ndim > 1 ):
467
- diffvector = np .sum ( diffvector , axis = 1 )
474
+ errvectorexp = np .zeros_like (mskvector )
475
+ errvectorexp [msk_idxs ] = errvector
468
476
469
- diffmap = diffvector .reshape ( mapshape )
477
+ errmap = errvectorexp .reshape (mapshape )
470
478
471
479
hdr = nii_ref .get_header ().copy ()
472
- hdr .set_data_dtype ( np .float32 )
480
+ hdr .set_data_dtype (np .float32 )
473
481
hdr ['data_type' ] = 16
474
- hdr .set_data_shape ( diffmap . shape )
482
+ hdr .set_data_shape (mapshape )
475
483
476
- niimap = nb .Nifti1Image ( diffmap .astype ( np .float32 ),
477
- nii_ref .get_affine (), hdr )
478
-
479
- if not isdefined ( self .inputs .out_map ):
480
- fname ,ext = op .splitext ( op .basename ( self .inputs .in_tst ) )
484
+ if not isdefined (self .inputs .out_map ):
485
+ fname ,ext = op .splitext (op .basename (self .inputs .in_tst ))
481
486
if ext == '.gz' :
482
- fname ,ext2 = op .splitext ( fname )
487
+ fname ,ext2 = op .splitext (fname )
483
488
ext = ext2 + ext
484
- self ._out_file = op .abspath ( fname + "_errmap" + ext )
489
+ self ._out_file = op .abspath (fname + "_errmap" + ext )
485
490
else :
486
491
self ._out_file = self .inputs .out_map
487
492
488
- nb .save ( niimap , self ._out_file )
493
+ nb .Nifti1Image (errmap .astype (np .float32 ), nii_ref .get_affine (),
494
+ hdr ).to_filename (self ._out_file )
489
495
490
496
return runtime
491
497
492
498
def _list_outputs (self ):
493
499
outputs = self .output_spec ().get ()
494
500
outputs ['out_map' ] = self ._out_file
495
-
496
501
return outputs
497
502
498
503
499
504
class SimilarityInputSpec (BaseInterfaceInputSpec ):
500
-
501
505
volume1 = File (exists = True , desc = "3D/4D volume" , mandatory = True )
502
506
volume2 = File (exists = True , desc = "3D/4D volume" , mandatory = True )
503
507
mask1 = File (exists = True , desc = "3D volume" )
0 commit comments