Skip to content

Commit 966cec0

Browse files
authored
Merge pull request #31 from thewtex/displacement-field-subsampling-factor
ENH: Add DisplacementFieldSubsamplingFactor
2 parents afd998c + f92efa9 commit 966cec0

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

include/itkANTSRegistration.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "itkCompositeTransform.h"
2424
#include "itkDataObjectDecorator.h"
2525
#include "itkantsRegistrationHelper.h"
26+
#include "itkDisplacementFieldTransformParametersAdaptor.h"
2627

2728
namespace itk
2829
{
@@ -240,11 +241,17 @@ class ANTSRegistration : public ProcessObject
240241
/** Set/Get the optimizer weights. When set, this allows restricting the optimization
241242
* of the displacement field, translation, rigid or affine transform on a per-component basis.
242243
* For example, to limit the deformation or rotation of 3-D volume to the first two dimensions,
243-
* specify a weight vector of ‘(1,1,0)’ for a 3D deformation field
244+
* specify a weight vector of ‘(1,1,0)’ for a 3D displacement field
244245
* or ‘(1,1,0,1,1,0)’ for a rigid transformation. */
245246
itkSetMacro(RestrictTransformation, std::vector<ParametersValueType>);
246247
itkGetConstReferenceMacro(RestrictTransformation, std::vector<ParametersValueType>);
247248

249+
/** Set/Get the subsampling factor for displacement fields results.
250+
* A factor of 1 results in no subsampling. This is applied in all dimensions.
251+
* The default is 2. */
252+
itkSetMacro(DisplacementFieldSubsamplingFactor, unsigned int);
253+
itkGetMacro(DisplacementFieldSubsamplingFactor, unsigned int);
254+
248255
virtual DecoratedOutputTransformType *
249256
GetOutput(DataObjectPointerArraySizeType i);
250257
virtual const DecoratedOutputTransformType *
@@ -286,6 +293,10 @@ class ANTSRegistration : public ProcessObject
286293
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType) override;
287294
using RegistrationHelperType = ::ants::RegistrationHelper<TParametersValueType, FixedImageType::ImageDimension>;
288295
using InternalImageType = typename RegistrationHelperType::ImageType; // float or double pixels
296+
using DisplacementFieldTransformType = typename RegistrationHelperType::DisplacementFieldTransformType;
297+
using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType;
298+
using DisplacementFieldTransformParametersAdaptorType =
299+
DisplacementFieldTransformParametersAdaptor<DisplacementFieldTransformType>;
289300

290301
template <typename TImage>
291302
typename InternalImageType::Pointer
@@ -346,6 +357,7 @@ class ANTSRegistration : public ProcessObject
346357
unsigned int m_Radius{ 4 };
347358
bool m_CollapseCompositeTransform{ true };
348359
bool m_MaskAllStages{ false };
360+
unsigned int m_DisplacementFieldSubsamplingFactor{ 2 };
349361

350362
std::vector<unsigned int> m_SynIterations{ 40, 20, 0 };
351363
std::vector<unsigned int> m_AffineIterations{ 2100, 1200, 1200, 10 };
@@ -355,7 +367,10 @@ class ANTSRegistration : public ProcessObject
355367
std::vector<ParametersValueType> m_RestrictTransformation;
356368

357369
private:
358-
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
370+
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
371+
typename DisplacementFieldTransformParametersAdaptorType::Pointer m_DisplacementFieldAdaptor{
372+
DisplacementFieldTransformParametersAdaptorType::New()
373+
};
359374

360375
#ifdef ITK_USE_CONCEPT_CHECKING
361376
static_assert(TFixedImage::ImageDimension == TMovingImage::ImageDimension,

include/itkANTSRegistration.hxx

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::PrintSelf(std
6767
os << indent << "Radius: " << this->m_Radius << std::endl;
6868
os << indent << "CollapseCompositeTransform: " << (this->m_CollapseCompositeTransform ? "On" : "Off") << std::endl;
6969
os << indent << "MaskAllStages: " << (this->m_MaskAllStages ? "On" : "Off") << std::endl;
70+
os << indent << "DisplacementFieldSubsamplingFactor: " << this->m_DisplacementFieldSubsamplingFactor << std::endl;
7071

7172
os << indent << "SynIterations: " << this->m_SynIterations << std::endl;
7273
os << indent << "AffineIterations: " << this->m_AffineIterations << std::endl;
@@ -257,8 +258,8 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::MakeOutput(Da
257258
template <typename TFixedImage, typename TMovingImage, typename TParametersValueType>
258259
template <typename TImage>
259260
auto
260-
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(
261-
const TImage * inputImage) -> typename InternalImageType::Pointer
261+
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(const TImage * inputImage) ->
262+
typename InternalImageType::Pointer
262263
{
263264
using CastFilterType = CastImageFilter<TImage, InternalImageType>;
264265
typename CastFilterType::Pointer castFilter = CastFilterType::New();
@@ -606,7 +607,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
606607
{
607608
itkExceptionMacro(<< "Unsupported transform type: " << this->GetTypeOfTransform());
608609
}
609-
this->UpdateProgress(0.95);
610+
this->UpdateProgress(0.90);
610611

611612
typename OutputTransformType::Pointer forwardTransform = m_Helper->GetModifiableCompositeTransform();
612613
if (m_CollapseCompositeTransform)
@@ -615,6 +616,41 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
615616
}
616617
this->SetForwardTransform(forwardTransform);
617618

619+
if (m_DisplacementFieldSubsamplingFactor > 1)
620+
{
621+
using TransformType = typename OutputTransformType::TransformType;
622+
for (unsigned int i = 0; i < forwardTransform->GetNumberOfTransforms(); ++i)
623+
{
624+
typename TransformType::Pointer transform = forwardTransform->GetNthTransform(i);
625+
typename DisplacementFieldTransformType::Pointer displacementFieldTransform =
626+
dynamic_cast<DisplacementFieldTransformType *>(transform.GetPointer());
627+
if (displacementFieldTransform)
628+
{
629+
// The transform is a DisplacementFieldTransform
630+
displacementFieldTransform->Print(std::cout, 3);
631+
const auto displacementField = displacementFieldTransform->GetDisplacementField();
632+
m_DisplacementFieldAdaptor->SetTransform(displacementFieldTransform);
633+
m_DisplacementFieldAdaptor->SetRequiredOrigin(displacementField->GetOrigin());
634+
m_DisplacementFieldAdaptor->SetRequiredDirection(displacementField->GetDirection());
635+
auto requiredSize = displacementField->GetLargestPossibleRegion().GetSize();
636+
for (unsigned int i = 0; i < requiredSize.GetSizeDimension(); ++i)
637+
{
638+
requiredSize[i] /= m_DisplacementFieldSubsamplingFactor;
639+
}
640+
m_DisplacementFieldAdaptor->SetRequiredSize(requiredSize);
641+
auto requiredSpacing = displacementField->GetSpacing();
642+
for (unsigned int i = 0; i < requiredSpacing.GetVectorDimension(); ++i)
643+
{
644+
requiredSpacing[i] *= m_DisplacementFieldSubsamplingFactor;
645+
}
646+
m_DisplacementFieldAdaptor->SetRequiredSpacing(requiredSpacing);
647+
m_DisplacementFieldAdaptor->AdaptTransformParameters();
648+
displacementFieldTransform->Print(std::cout, 3);
649+
}
650+
}
651+
}
652+
this->UpdateProgress(0.95);
653+
618654
typename OutputTransformType::Pointer inverseTransform = OutputTransformType::New();
619655
if (forwardTransform->GetInverse(inverseTransform))
620656
{

0 commit comments

Comments
 (0)