Skip to content

Commit d9d7682

Browse files
author
FredLoney
committed
Support multiple metric arguments per stage.
1 parent 9608c58 commit d9d7682

File tree

1 file changed

+119
-34
lines changed

1 file changed

+119
-34
lines changed

nipype/interfaces/ants/registration.py

Lines changed: 119 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -208,21 +208,42 @@ class RegistrationInputSpec(ANTSCommandInputSpec):
208208
initial_moving_transform_com = traits.Bool(argstr='%s',
209209
xor=['initial_moving_transform'],
210210
desc="Use center of mass for moving transform")
211-
metric = traits.List(traits.Enum("CC", "MeanSquares", "Demons",
212-
"GC", "MI", "Mattes"), mandatory=True, desc='')
211+
metric_item_trait = traits.Enum("CC", "MeanSquares", "Demons", "GC", "MI",
212+
"Mattes")
213+
metric_stage_trait = traits.Either(
214+
metric_item_trait, traits.List(metric_item_trait))
215+
metric = traits.List(metric_stage_trait, mandatory=True,
216+
desc='the metric(s) to use for each stage. '
217+
'Note that multiple metrics per stage are not supported '
218+
'in ANTS 1.9.1 and earlier.')
219+
metric_weight_item_trait = traits.Float(1.0)
220+
metric_weight_stage_trait = traits.Either(
221+
metric_weight_item_trait, traits.List(metric_weight_item_trait))
213222
metric_weight = traits.List(
214-
traits.Float(1.0), usedefault=True, requires=['metric'], mandatory=True,
215-
desc="Note that the metricWeight is currently not used. Rather, it is a temporary place \
216-
holder until multivariate metrics are available for a single stage.")
217-
### This is interpreted as number_of_bins for MI and Mattes, and as radius for all other metrics
223+
metric_weight_stage_trait, value=[1.0], usedefault=True,
224+
requires=['metric'], mandatory=True,
225+
desc='the metric weight(s) for each stage. '
226+
'The weights must sum to 1 per stage.')
227+
radius_bins_item_trait = traits.Int(5)
228+
radius_bins_stage_trait = traits.Either(
229+
radius_bins_item_trait, traits.List(radius_bins_item_trait))
218230
radius_or_number_of_bins = traits.List(
219-
traits.Int(5), usedefault=True, requires=['metric_weight'], desc='')
231+
radius_bins_stage_trait, value=[5], usedefault=True,
232+
requires=['metric_weight'],
233+
desc='the number of bins in each stage for the MI and Mattes metric, '
234+
'the radius for other metrics')
235+
sampling_strategy_item_trait = traits.Enum("Dense", "Regular", "Random", None)
236+
sampling_strategy_stage_trait = traits.Either(
237+
sampling_strategy_item_trait, traits.List(sampling_strategy_item_trait))
220238
sampling_strategy = traits.List(
221-
trait=traits.Enum("Dense", "Regular", "Random", None), value=['Dense'], minlen=1,
222-
usedefault=True, requires=['metric_weight'], desc='')
239+
trait=sampling_strategy_stage_trait, requires=['metric_weight'],
240+
desc='the metric sampling strategy (strategies) for each stage')
241+
sampling_percentage_item_trait = traits.Either(traits.Range(low=0.0, high=1.0), None)
242+
sampling_percentage_stage_trait = traits.Either(
243+
sampling_percentage_item_trait, traits.List(sampling_percentage_item_trait))
223244
sampling_percentage = traits.List(
224-
trait=traits.Either(traits.Range(low=0.0, high=1.0), None), value=[None], minlen=1,
225-
requires=['sampling_strategy'], desc='')
245+
trait=sampling_percentage_stage_trait, requires=['sampling_strategy'],
246+
desc="the metric sampling percentage(s) to use for each stage")
226247
use_estimate_learning_rate_once = traits.List(traits.Bool(), desc='')
227248
use_histogram_matching = traits.List(
228249
traits.Bool(argstr='%s'), default=True, usedefault=True)
@@ -334,19 +355,19 @@ class Registration(ANTSCommand):
334355
>>> reg1.inputs.winsorize_lower_quantile = 0.025
335356
>>> reg1.inputs.collapse_linear_transforms_to_fixed_image_header = False
336357
>>> reg1.cmdline
337-
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ,Random,0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 1.0 ] --write-composite-transform 1'
358+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 1.0 ] --write-composite-transform 1'
338359
>>> reg1.run() #doctest: +SKIP
339360
340361
>>> reg2 = copy.deepcopy(reg)
341362
>>> reg2.inputs.winsorize_upper_quantile = 0.975
342363
>>> reg2.cmdline
343-
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ,Random,0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.0, 0.975 ] --write-composite-transform 1'
364+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.0, 0.975 ] --write-composite-transform 1'
344365
345366
>>> reg3 = copy.deepcopy(reg)
346367
>>> reg3.inputs.winsorize_lower_quantile = 0.025
347368
>>> reg3.inputs.winsorize_upper_quantile = 0.975
348369
>>> reg3.cmdline
349-
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ,Random,0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 0.975 ] --write-composite-transform 1'
370+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric Mattes[ fixed1.nii, moving1.nii, 1, 32 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.025, 0.975 ] --write-composite-transform 1'
350371
351372
# Test collapse transforms flag
352373
>>> reg4 = copy.deepcopy(reg)
@@ -356,32 +377,95 @@ class Registration(ANTSCommand):
356377
{'reverse_invert_flags': [True, False], 'inverse_composite_transform': ['.../nipype/testing/data/output_InverseComposite.h5'], 'warped_image': '.../nipype/testing/data/output_warped_image.nii.gz', 'inverse_warped_image': <undefined>, 'forward_invert_flags': [False, False], 'reverse_transforms': ['.../nipype/testing/data/output_0GenericAffine.mat', '.../nipype/testing/data/output_1InverseWarp.nii.gz'], 'composite_transform': ['.../nipype/testing/data/output_Composite.h5'], 'forward_transforms': ['.../nipype/testing/data/output_0GenericAffine.mat', '.../nipype/testing/data/output_1Warp.nii.gz']}
357378
>>> reg4.aggregate_outputs() #doctest: +SKIP
358379
380+
# Test multiple metrics per stage
381+
>>> reg5 = copy.deepcopy(reg)
382+
>>> reg5.inputs.metric = ['CC', ['CC', 'Mattes']]
383+
>>> reg5.inputs.metric_weight = [1, [.5]*2]
384+
>>> reg5.inputs.radius_or_number_of_bins = [4, [32]*2]
385+
>>> reg5.inputs.sampling_strategy = ['Random', None] # use default strategy in second stage
386+
>>> reg5.inputs.sampling_percentage = [0.05, [0.05, 0.10]]
387+
>>> reg5.cmdline
388+
'antsRegistration --collapse-linear-transforms-to-fixed-image-header 0 --collapse-output-transforms 0 --dimensionality 3 --initial-moving-transform [ trans.mat, 1 ] --interpolation Linear --output [ output_, output_warped_image.nii.gz ] --transform Affine[ 2.0 ] --metric CC[ fixed1.nii, moving1.nii, 1, 4, Random, 0.05 ] --convergence [ 1500x200, 1e-08, 20 ] --smoothing-sigmas 1x0vox --shrink-factors 2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --transform SyN[ 0.25, 3.0, 0.0 ] --metric CC[ fixed1.nii, moving1.nii, 0.5, 32, Dense, 0.05 ] --metric Mattes[ fixed1.nii, moving1.nii, 0.5, 32, Dense, 0.1 ] --convergence [ 100x50x30, 1e-09, 20 ] --smoothing-sigmas 2x1x0vox --shrink-factors 3x2x1 --use-estimate-learning-rate-once 1 --use-histogram-matching 1 --winsorize-image-intensities [ 0.0, 1.0 ] --write-composite-transform 1'
359389
"""
390+
DEF_SAMPLING_STRATEGY = 'Dense'
391+
"""The default sampling stratey argument."""
392+
360393
_cmd = 'antsRegistration'
361394
input_spec = RegistrationInputSpec
362395
output_spec = RegistrationOutputSpec
363396
_quantilesDone = False
364397

365-
def _optionalMetricParameters(self, index):
366-
if (len(self.inputs.sampling_strategy) > index) and (self.inputs.sampling_strategy[index] is not None):
367-
if self.inputs.sampling_strategy[index] == "Dense":
368-
return '' # The default when nothing is specified
369-
if isdefined(self.inputs.sampling_percentage) and (self.inputs.sampling_percentage is not None):
370-
return ',%s,%g' % (self.inputs.sampling_strategy[index], self.inputs.sampling_percentage[index])
371-
else:
372-
return ',%s' % self.inputs.sampling_strategy[index]
373-
return ''
374-
375398
def _formatMetric(self, index):
376-
retval = []
377-
retval.append(
378-
'%s[ %s, %s, %g, %d' % (self.inputs.metric[index], self.inputs.fixed_image[0],
379-
self.inputs.moving_image[
380-
0], self.inputs.metric_weight[index],
381-
self.inputs.radius_or_number_of_bins[index]))
382-
retval.append(' %s' % self._optionalMetricParameters(index))
383-
retval.append(' ]')
384-
return "".join(retval)
399+
"""
400+
Format the antsRegistration -m metric argument(s).
401+
402+
Parameters
403+
----------
404+
index: the stage index
405+
"""
406+
# The common fixed image.
407+
fixed = self.inputs.fixed_image[0]
408+
# The common moving image.
409+
moving = self.inputs.moving_image[0]
410+
# The metric name input for the current stage.
411+
name_input = self.inputs.metric[index]
412+
# The stage-specific input dictionary.
413+
stage_inputs = dict(
414+
metric=name_input,
415+
weight=self.inputs.metric_weight[index],
416+
radius_or_bins=self.inputs.radius_or_number_of_bins[index],
417+
optional=self.inputs.radius_or_number_of_bins[index]
418+
)
419+
# The optional sampling strategy and percentage.
420+
if (isdefined(self.inputs.sampling_strategy) and self.inputs.sampling_strategy):
421+
sampling_strategy = self.inputs.sampling_strategy[index]
422+
if sampling_strategy:
423+
stage_inputs['sampling_strategy'] = sampling_strategy
424+
sampling_percentage = self.inputs.sampling_percentage
425+
if (isdefined(self.inputs.sampling_percentage) and self.inputs.sampling_percentage):
426+
sampling_percentage = self.inputs.sampling_percentage[index]
427+
if sampling_percentage:
428+
stage_inputs['sampling_percentage'] = sampling_percentage
429+
430+
# Make a list of metric specifications, one per -m command line
431+
# argument for the current stage.
432+
# If there are multiple inputs for this stage, then convert the
433+
# dictionary of list inputs into a list of metric specifications.
434+
# Otherwise, make a singleton list of the metric specification
435+
# from the non-list inputs.
436+
if isinstance(name_input, list):
437+
items = stage_inputs.items()
438+
indexes = range(0, len(name_input))
439+
specs = [{k: v[i] for k, v in items} for i in indexes]
440+
else:
441+
specs = [stage_inputs]
442+
443+
# Format the --metric command line metric arguments, one per specification.
444+
return [self._formatMetricArgument(fixed, moving, **spec) for spec in specs]
445+
446+
def _formatMetricArgument(self, fixed, moving, **kwargs):
447+
retval = '%s[ %s, %s, %g, %d' % (kwargs['metric'],
448+
fixed, moving, kwargs['weight'],
449+
kwargs['radius_or_bins'])
450+
451+
# The optional sampling strategy.
452+
if kwargs.has_key('sampling_strategy'):
453+
sampling_strategy = kwargs['sampling_strategy']
454+
elif kwargs.has_key('sampling_percentage'):
455+
# The sampling percentage is specified but not the
456+
# sampling strategy. Use the default strategy.
457+
sampling_strategy = Registration.DEF_SAMPLING_STRATEGY
458+
else:
459+
sampling_strategy = None
460+
# Format the optional sampling arguments.
461+
if sampling_strategy:
462+
retval += ', %s' % sampling_strategy
463+
if kwargs.has_key('sampling_percentage'):
464+
retval += ', %g' % kwargs['sampling_percentage']
465+
466+
retval += ' ]'
467+
468+
return retval
385469

386470
def _formatTransform(self, index):
387471
retval = []
@@ -396,7 +480,8 @@ def _formatRegistration(self):
396480
retval = []
397481
for ii in range(len(self.inputs.transforms)):
398482
retval.append('--transform %s' % (self._formatTransform(ii)))
399-
retval.append('--metric %s' % self._formatMetric(ii))
483+
for metric in self._formatMetric(ii):
484+
retval.append('--metric %s' % metric)
400485
retval.append('--convergence %s' % self._formatConvergence(ii))
401486
retval.append('--smoothing-sigmas %s%s' % (self._antsJoinList(
402487
self.inputs.smoothing_sigmas[ii]),

0 commit comments

Comments
 (0)