Skip to content

Commit f92efa9

Browse files
committed
ENH: Add DisplacementFieldSubsamplingFactor
Output ANTs displacement fields are very dense, sampled with the fixed image. This is generally over-sampled for the purpose of downstream use. Memory usage is very high and serialization and deserialization takes quite some time. Add a DisplacementFieldSubsamplingFactor parameter for downsampling the resulting displacement fields. This is applied in all directions for both the forward and inverse transform. If the DisplacementFieldSubsamplingFactor is greater than 1, this is applied. The current default is 2. This uses the itk::DisplacementFieldTransformParametersAdapter. In the future, we may want to increase the default, and / or use the itk::GaussianSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor with larger factors to avoid aliasing.
1 parent afd998c commit f92efa9

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

include/itkANTSRegistration.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "itkCompositeTransform.h"
2424
#include "itkDataObjectDecorator.h"
2525
#include "itkantsRegistrationHelper.h"
26+
#include "itkDisplacementFieldTransformParametersAdaptor.h"
2627

2728
namespace itk
2829
{
@@ -240,11 +241,17 @@ class ANTSRegistration : public ProcessObject
240241
/** Set/Get the optimizer weights. When set, this allows restricting the optimization
241242
* of the displacement field, translation, rigid or affine transform on a per-component basis.
242243
* For example, to limit the deformation or rotation of 3-D volume to the first two dimensions,
243-
* specify a weight vector of ‘(1,1,0)’ for a 3D deformation field
244+
* specify a weight vector of ‘(1,1,0)’ for a 3D displacement field
244245
* or ‘(1,1,0,1,1,0)’ for a rigid transformation. */
245246
itkSetMacro(RestrictTransformation, std::vector<ParametersValueType>);
246247
itkGetConstReferenceMacro(RestrictTransformation, std::vector<ParametersValueType>);
247248

249+
/** Set/Get the subsampling factor for displacement fields results.
250+
* A factor of 1 results in no subsampling. This is applied in all dimensions.
251+
* The default is 2. */
252+
itkSetMacro(DisplacementFieldSubsamplingFactor, unsigned int);
253+
itkGetMacro(DisplacementFieldSubsamplingFactor, unsigned int);
254+
248255
virtual DecoratedOutputTransformType *
249256
GetOutput(DataObjectPointerArraySizeType i);
250257
virtual const DecoratedOutputTransformType *
@@ -286,6 +293,10 @@ class ANTSRegistration : public ProcessObject
286293
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType) override;
287294
using RegistrationHelperType = ::ants::RegistrationHelper<TParametersValueType, FixedImageType::ImageDimension>;
288295
using InternalImageType = typename RegistrationHelperType::ImageType; // float or double pixels
296+
using DisplacementFieldTransformType = typename RegistrationHelperType::DisplacementFieldTransformType;
297+
using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType;
298+
using DisplacementFieldTransformParametersAdaptorType =
299+
DisplacementFieldTransformParametersAdaptor<DisplacementFieldTransformType>;
289300

290301
template <typename TImage>
291302
typename InternalImageType::Pointer
@@ -346,6 +357,7 @@ class ANTSRegistration : public ProcessObject
346357
unsigned int m_Radius{ 4 };
347358
bool m_CollapseCompositeTransform{ true };
348359
bool m_MaskAllStages{ false };
360+
unsigned int m_DisplacementFieldSubsamplingFactor{ 2 };
349361

350362
std::vector<unsigned int> m_SynIterations{ 40, 20, 0 };
351363
std::vector<unsigned int> m_AffineIterations{ 2100, 1200, 1200, 10 };
@@ -355,7 +367,10 @@ class ANTSRegistration : public ProcessObject
355367
std::vector<ParametersValueType> m_RestrictTransformation;
356368

357369
private:
358-
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
370+
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
371+
typename DisplacementFieldTransformParametersAdaptorType::Pointer m_DisplacementFieldAdaptor{
372+
DisplacementFieldTransformParametersAdaptorType::New()
373+
};
359374

360375
#ifdef ITK_USE_CONCEPT_CHECKING
361376
static_assert(TFixedImage::ImageDimension == TMovingImage::ImageDimension,

include/itkANTSRegistration.hxx

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::PrintSelf(std
6767
os << indent << "Radius: " << this->m_Radius << std::endl;
6868
os << indent << "CollapseCompositeTransform: " << (this->m_CollapseCompositeTransform ? "On" : "Off") << std::endl;
6969
os << indent << "MaskAllStages: " << (this->m_MaskAllStages ? "On" : "Off") << std::endl;
70+
os << indent << "DisplacementFieldSubsamplingFactor: " << this->m_DisplacementFieldSubsamplingFactor << std::endl;
7071

7172
os << indent << "SynIterations: " << this->m_SynIterations << std::endl;
7273
os << indent << "AffineIterations: " << this->m_AffineIterations << std::endl;
@@ -257,8 +258,8 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::MakeOutput(Da
257258
template <typename TFixedImage, typename TMovingImage, typename TParametersValueType>
258259
template <typename TImage>
259260
auto
260-
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(
261-
const TImage * inputImage) -> typename InternalImageType::Pointer
261+
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(const TImage * inputImage) ->
262+
typename InternalImageType::Pointer
262263
{
263264
using CastFilterType = CastImageFilter<TImage, InternalImageType>;
264265
typename CastFilterType::Pointer castFilter = CastFilterType::New();
@@ -606,7 +607,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
606607
{
607608
itkExceptionMacro(<< "Unsupported transform type: " << this->GetTypeOfTransform());
608609
}
609-
this->UpdateProgress(0.95);
610+
this->UpdateProgress(0.90);
610611

611612
typename OutputTransformType::Pointer forwardTransform = m_Helper->GetModifiableCompositeTransform();
612613
if (m_CollapseCompositeTransform)
@@ -615,6 +616,41 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
615616
}
616617
this->SetForwardTransform(forwardTransform);
617618

619+
if (m_DisplacementFieldSubsamplingFactor > 1)
620+
{
621+
using TransformType = typename OutputTransformType::TransformType;
622+
for (unsigned int i = 0; i < forwardTransform->GetNumberOfTransforms(); ++i)
623+
{
624+
typename TransformType::Pointer transform = forwardTransform->GetNthTransform(i);
625+
typename DisplacementFieldTransformType::Pointer displacementFieldTransform =
626+
dynamic_cast<DisplacementFieldTransformType *>(transform.GetPointer());
627+
if (displacementFieldTransform)
628+
{
629+
// The transform is a DisplacementFieldTransform
630+
displacementFieldTransform->Print(std::cout, 3);
631+
const auto displacementField = displacementFieldTransform->GetDisplacementField();
632+
m_DisplacementFieldAdaptor->SetTransform(displacementFieldTransform);
633+
m_DisplacementFieldAdaptor->SetRequiredOrigin(displacementField->GetOrigin());
634+
m_DisplacementFieldAdaptor->SetRequiredDirection(displacementField->GetDirection());
635+
auto requiredSize = displacementField->GetLargestPossibleRegion().GetSize();
636+
for (unsigned int i = 0; i < requiredSize.GetSizeDimension(); ++i)
637+
{
638+
requiredSize[i] /= m_DisplacementFieldSubsamplingFactor;
639+
}
640+
m_DisplacementFieldAdaptor->SetRequiredSize(requiredSize);
641+
auto requiredSpacing = displacementField->GetSpacing();
642+
for (unsigned int i = 0; i < requiredSpacing.GetVectorDimension(); ++i)
643+
{
644+
requiredSpacing[i] *= m_DisplacementFieldSubsamplingFactor;
645+
}
646+
m_DisplacementFieldAdaptor->SetRequiredSpacing(requiredSpacing);
647+
m_DisplacementFieldAdaptor->AdaptTransformParameters();
648+
displacementFieldTransform->Print(std::cout, 3);
649+
}
650+
}
651+
}
652+
this->UpdateProgress(0.95);
653+
618654
typename OutputTransformType::Pointer inverseTransform = OutputTransformType::New();
619655
if (forwardTransform->GetInverse(inverseTransform))
620656
{

0 commit comments

Comments
 (0)