Skip to content

Commit 66ea1c2

Browse files
committed
Merge pull request #579 from FredLoney/ants_multiple_metrics_per_stage
Support multiple metric arguments per stage.
2 parents ab8a5de + d9d7682 commit 66ea1c2

File tree

1 file changed

+119
-34
lines changed

1 file changed

+119
-34
lines changed

nipype/interfaces/ants/registration.py

Lines changed: 119 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -209,21 +209,42 @@ class RegistrationInputSpec(ANTSCommandInputSpec):
209209
initial_moving_transform_com = traits.Bool(argstr='%s',
210210
default=False, xor=['initial_moving_transform'],
211211
desc="Use center of mass for moving transform")
212-
metric = traits.List(traits.Enum("CC", "MeanSquares", "Demons",
213-
"GC", "MI", "Mattes"), mandatory=True, desc='')
212+
metric_item_trait = traits.Enum("CC", "MeanSquares", "Demons", "GC", "MI",
213+
"Mattes")
214+
metric_stage_trait = traits.Either(
215+
metric_item_trait, traits.List(metric_item_trait))
216+
metric = traits.List(metric_stage_trait, mandatory=True,
217+
desc='the metric(s) to use for each stage. '
218+
'Note that multiple metrics per stage are not supported '
219+
'in ANTS 1.9.1 and earlier.')
220+
metric_weight_item_trait = traits.Float(1.0)
221+
metric_weight_stage_trait = traits.Either(
222+
metric_weight_item_trait, traits.List(metric_weight_item_trait))
214223
metric_weight = traits.List(
215-
traits.Float(1.0), usedefault=True, requires=['metric'], mandatory=True,
216-
desc="Note that the metricWeight is currently not used. Rather, it is a temporary place \
217-
holder until multivariate metrics are available for a single stage.")
218-
### This is interpreted as number_of_bins for MI and Mattes, and as radius for all other metrics
224+
metric_weight_stage_trait, value=[1.0], usedefault=True,
225+
requires=['metric'], mandatory=True,
226+
desc='the metric weight(s) for each stage. '
227+
'The weights must sum to 1 per stage.')
228+
radius_bins_item_trait = traits.Int(5)
229+
radius_bins_stage_trait = traits.Either(
230+
radius_bins_item_trait, traits.List(radius_bins_item_trait))
219231
radius_or_number_of_bins = traits.List(
220-
traits.Int(5), usedefault=True, requires=['metric_weight'], desc='')
232+
radius_bins_stage_trait, value=[5], usedefault=True,
233+
requires=['metric_weight'],
234+
desc='the number of bins in each stage for the MI and Mattes metric, '
235+
'the radius for other metrics')
236+
sampling_strategy_item_trait = traits.Enum("Dense", "Regular", "Random", None)
237+
sampling_strategy_stage_trait = traits.Either(
238+
sampling_strategy_item_trait, traits.List(sampling_strategy_item_trait))
221239
sampling_strategy = traits.List(
222-
trait=traits.Enum("Dense", "Regular", "Random", None), value=['Dense'], minlen=1,
223-
usedefault=True, requires=['metric_weight'], desc='')
240+
trait=sampling_strategy_stage_trait, requires=['metric_weight'],
241+
desc='the metric sampling strategy (strategies) for each stage')
242+
sampling_percentage_item_trait = traits.Either(traits.Range(low=0.0, high=1.0), None)
243+
sampling_percentage_stage_trait = traits.Either(
244+
sampling_percentage_item_trait, traits.List(sampling_percentage_item_trait))
224245
sampling_percentage = traits.List(
225-
trait=traits.Either(traits.Range(low=0.0, high=1.0), None), value=[None], minlen=1,
226-
requires=['sampling_strategy'], desc='')
246+
trait=sampling_percentage_stage_trait, requires=['sampling_strategy'],
247+
desc="the metric sampling percentage(s) to use for each stage")
227248
use_estimate_learning_rate_once = traits.List(traits.Bool(), desc='')
228249
use_histogram_matching = traits.List(
229250
traits.Bool(argstr='%s'), default=True, usedefault=True)
@@ -335,19 +356,19 @@ class Registration(ANTSCommand):
335356
>>> reg1.inputs.winsorize_lower_quantile = 0.025
336357
>>> reg1.inputs.collapse_linear_transforms_to_fixed_image_header = False
337358
>>> reg1.cmdline
338-
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ,Random,0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 1.0 ] --write-composite-transform 1'
359+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 1.0 ] --write-composite-transform 1'
339360
>>> reg1.run() #doctest: +SKIP
340361
341362
>>> reg2 = copy.deepcopy(reg)
342363
>>> reg2.inputs.winsorize_upper_quantile = 0.975
343364
>>> reg2.cmdline
344-
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ,Random,0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.0, 0.975 ] --write-composite-transform 1'
365+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.0, 0.975 ] --write-composite-transform 1'
345366
346367
>>> reg3 = copy.deepcopy(reg)
347368
>>> reg3.inputs.winsorize_lower_quantile = 0.025
348369
>>> reg3.inputs.winsorize_upper_quantile = 0.975
349370
>>> reg3.cmdline
350-
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ,Random,0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 0.975 ] --write-composite-transform 1'
371+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 0.975 ] --write-composite-transform 1'
351372
352373
# Test collapse transforms flag
353374
>>> reg4 = copy.deepcopy(reg)
@@ -357,32 +378,95 @@ class Registration(ANTSCommand):
357378
{'reverse_invert_flags': [True, False], 'inverse_composite_transform': ['.../nipype/testing/data/output_InverseComposite.h5'], 'warped_image': '.../nipype/testing/data/output_warped_image.nii.gz', 'inverse_warped_image': <undefined>, 'forward_invert_flags': [False, False], 'reverse_transforms': ['.../nipype/testing/data/output_0GenericAffine.mat', '.../nipype/testing/data/output_1InverseWarp.nii.gz'], 'composite_transform': ['.../nipype/testing/data/output_Composite.h5'], 'forward_transforms': ['.../nipype/testing/data/output_0GenericAffine.mat', '.../nipype/testing/data/output_1Warp.nii.gz']}
358379
>>> reg4.aggregate_outputs() #doctest: +SKIP
359380
381+
# Test multiple metrics per stage
382+
>>> reg5 = copy.deepcopy(reg)
383+
>>> reg5.inputs.metric = ['CC', ['CC', 'Mattes']]
384+
>>> reg5.inputs.metric_weight = [1, [.5]*2]
385+
>>> reg5.inputs.radius_or_number_of_bins = [4, [32]*2]
386+
>>> reg5.inputs.sampling_strategy = ['Random', None] # use default strategy in second stage
387+
>>> reg5.inputs.sampling_percentage = [0.05, [0.05, 0.10]]
388+
>>> reg5.cmdline
389+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric CC[ fixed1.nii, moving1.nii, 1, 4, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric CC[ fixed1.nii, moving1.nii, 0.5, 32, Dense, 0.05 ] --metric Mattes[ fixed1.nii, moving1.nii, 0.5, 32, Dense, 0.1 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.0, 1.0 ] --write-composite-transform 1'
360390
"""
391+
DEF_SAMPLING_STRATEGY = 'Dense'
392+
"""The default sampling stratey argument."""
393+
361394
_cmd = 'antsRegistration'
362395
input_spec = RegistrationInputSpec
363396
output_spec = RegistrationOutputSpec
364397
_quantilesDone = False
365398

366-
def _optionalMetricParameters(self, index):
367-
if (len(self.inputs.sampling_strategy) > index) and (self.inputs.sampling_strategy[index] is not None):
368-
if self.inputs.sampling_strategy[index] == "Dense":
369-
return '' # The default when nothing is specified
370-
if isdefined(self.inputs.sampling_percentage) and (self.inputs.sampling_percentage is not None):
371-
return ',%s,%g' % (self.inputs.sampling_strategy[index], self.inputs.sampling_percentage[index])
372-
else:
373-
return ',%s' % self.inputs.sampling_strategy[index]
374-
return ''
375-
376399
def _formatMetric(self, index):
377-
retval = []
378-
retval.append(
379-
'%s[ %s, %s, %g, %d' % (self.inputs.metric[index], self.inputs.fixed_image[0],
380-
self.inputs.moving_image[
381-
0], self.inputs.metric_weight[index],
382-
self.inputs.radius_or_number_of_bins[index]))
383-
retval.append(' %s' % self._optionalMetricParameters(index))
384-
retval.append(' ]')
385-
return "".join(retval)
400+
"""
401+
Format the antsRegistration -m metric argument(s).
402+
403+
Parameters
404+
----------
405+
index: the stage index
406+
"""
407+
# The common fixed image.
408+
fixed = self.inputs.fixed_image[0]
409+
# The common moving image.
410+
moving = self.inputs.moving_image[0]
411+
# The metric name input for the current stage.
412+
name_input = self.inputs.metric[index]
413+
# The stage-specific input dictionary.
414+
stage_inputs = dict(
415+
metric=name_input,
416+
weight=self.inputs.metric_weight[index],
417+
radius_or_bins=self.inputs.radius_or_number_of_bins[index],
418+
optional=self.inputs.radius_or_number_of_bins[index]
419+
)
420+
# The optional sampling strategy and percentage.
421+
if (isdefined(self.inputs.sampling_strategy) and self.inputs.sampling_strategy):
422+
sampling_strategy = self.inputs.sampling_strategy[index]
423+
if sampling_strategy:
424+
stage_inputs['sampling_strategy'] = sampling_strategy
425+
sampling_percentage = self.inputs.sampling_percentage
426+
if (isdefined(self.inputs.sampling_percentage) and self.inputs.sampling_percentage):
427+
sampling_percentage = self.inputs.sampling_percentage[index]
428+
if sampling_percentage:
429+
stage_inputs['sampling_percentage'] = sampling_percentage
430+
431+
# Make a list of metric specifications, one per -m command line
432+
# argument for the current stage.
433+
# If there are multiple inputs for this stage, then convert the
434+
# dictionary of list inputs into a list of metric specifications.
435+
# Otherwise, make a singleton list of the metric specification
436+
# from the non-list inputs.
437+
if isinstance(name_input, list):
438+
items = stage_inputs.items()
439+
indexes = range(0, len(name_input))
440+
specs = [{k: v[i] for k, v in items} for i in indexes]
441+
else:
442+
specs = [stage_inputs]
443+
444+
# Format the --metric command line metric arguments, one per specification.
445+
return [self._formatMetricArgument(fixed, moving, **spec) for spec in specs]
446+
447+
def _formatMetricArgument(self, fixed, moving, **kwargs):
448+
retval = '%s[ %s, %s, %g, %d' % (kwargs['metric'],
449+
fixed, moving, kwargs['weight'],
450+
kwargs['radius_or_bins'])
451+
452+
# The optional sampling strategy.
453+
if kwargs.has_key('sampling_strategy'):
454+
sampling_strategy = kwargs['sampling_strategy']
455+
elif kwargs.has_key('sampling_percentage'):
456+
# The sampling percentage is specified but not the
457+
# sampling strategy. Use the default strategy.
458+
sampling_strategy = Registration.DEF_SAMPLING_STRATEGY
459+
else:
460+
sampling_strategy = None
461+
# Format the optional sampling arguments.
462+
if sampling_strategy:
463+
retval += ', %s' % sampling_strategy
464+
if kwargs.has_key('sampling_percentage'):
465+
retval += ', %g' % kwargs['sampling_percentage']
466+
467+
retval += ' ]'
468+
469+
return retval
386470

387471
def _formatTransform(self, index):
388472
retval = []
@@ -397,7 +481,8 @@ def _formatRegistration(self):
397481
retval = []
398482
for ii in range(len(self.inputs.transforms)):
399483
retval.append('--transform %s' % (self._formatTransform(ii)))
400-
retval.append('--metric %s' % self._formatMetric(ii))
484+
for metric in self._formatMetric(ii):
485+
retval.append('--metric %s' % metric)
401486
retval.append('--convergence %s' % self._formatConvergence(ii))
402487
retval.append('--smoothing-sigmas %s%s' % (self._antsJoinList(
403488
self.inputs.smoothing_sigmas[ii]),

0 commit comments

Comments
 (0)