Skip to content

Commit 18f92d9

Browse files
committed
Object-oriented interface, with fit and apply methods.
1 parent a3bfa1a commit 18f92d9

File tree

1 file changed

+68
-52
lines changed

1 file changed

+68
-52
lines changed

dmriprep/utils/register.py

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Linear affine registration tools for motion correction.
33
"""
4+
import attr
5+
46
import numpy as np
57
import nibabel as nb
68
from dipy.align.metrics import CCMetric, EMMetric, SSDMetric
@@ -72,7 +74,7 @@ def c_of_mass(
7274
):
7375
transform = transform_centers_of_mass(static, static_affine, moving, moving_affine)
7476
transformed = transform.transform(moving)
75-
return transformed, transform.affine
77+
return transform
7678

7779

7880
def translation(
@@ -89,7 +91,7 @@ def translation(
8991
starting_affine=starting_affine,
9092
)
9193

92-
return translation.transform(moving), translation.affine
94+
return translation
9395

9496

9597
def rigid(
@@ -105,12 +107,13 @@ def rigid(
105107
moving_affine,
106108
starting_affine=starting_affine,
107109
)
108-
return rigid.transform(moving), rigid.affine
110+
return rigid
109111

110112

111-
def affine(
112-
moving, static, static_affine, moving_affine, reg, starting_affine, params0=None
113-
):
113+
def affine(moving, static, static_affine, moving_affine, reg, starting_affine,
114+
params0=None):
115+
"""
116+
"""
114117
transform = AffineTransform3D()
115118
affine = reg.optimize(
116119
static,
@@ -122,49 +125,62 @@ def affine(
122125
starting_affine=starting_affine,
123126
)
124127

125-
return affine.transform(moving), affine.affine
126-
127-
128-
def affine_registration(
129-
moving,
130-
static,
131-
nbins,
132-
sampling_prop,
133-
metric,
134-
pipeline,
135-
level_iters,
136-
sigmas,
137-
factors,
138-
params0,
139-
):
140-
141-
"""
142-
Find the affine transformation between two 3D images.
143-
144-
Parameters
145-
----------
146-
147-
"""
148-
# Define the Affine registration object we'll use with the chosen metric:
149-
use_metric = affine_metric_dict[metric](nbins, sampling_prop)
150-
affreg = AffineRegistration(
151-
metric=use_metric, level_iters=level_iters, sigmas=sigmas, factors=factors
152-
)
153-
154-
if not params0:
155-
starting_affine = np.eye(4)
156-
else:
157-
starting_affine = params0
158-
159-
# Go through the selected transformation:
160-
for func in pipeline:
161-
transformed, starting_affine = func(
162-
np.asarray(moving.dataobj),
163-
np.asarray(static.dataobj),
164-
static.affine,
165-
moving.affine,
166-
affreg,
167-
starting_affine,
168-
params0,
169-
)
170-
return nb.Nifti1Image(np.array(transformed), static.affine), starting_affine
128+
return affine
129+
130+
131+
@attr.s(slots=True, frozen=True)
132+
class AffineRegistration():
133+
def __init__(self):
134+
nbins = attr.ib(default=32)
135+
sampling_prop = attr.ib(default=1.0)
136+
metric = attr.ib(default="MI")
137+
level_iters = attr.ib(default=[10000, 1000, 100])
138+
sigmas = attr.ib(defaults=[3, 1, 0.0])
139+
factors = attr.ib(defaults=[4, 2, 1])
140+
pipeline = attr.ib(defaults=[c_of_mass, translation, rigid, affine])
141+
142+
def fit(self, static, moving, params0=None):
143+
"""
144+
static, moving : nib.Nifti1Image class images
145+
"""
146+
if params0 is None:
147+
starting_affine = np.eye(4)
148+
else:
149+
starting_affine = params0
150+
151+
use_metric = affine_metric_dict[self.metric](self.nbins,
152+
self.sampling_prop)
153+
affreg = AffineRegistration(
154+
metric=use_metric,
155+
level_iters=self.level_iters,
156+
sigmas=self.sigmas,
157+
factors=self.factors)
158+
159+
# Go through the selected transformation:
160+
for func in self.pipeline:
161+
transform = func(
162+
np.asarray(moving.dataobj),
163+
np.asarray(static.dataobj),
164+
static.affine,
165+
moving.affine,
166+
affreg,
167+
starting_affine,
168+
params0,
169+
)
170+
starting_affine = transform.affine
171+
172+
self.static_affine_ = static.affine
173+
self.moving_affine_ = moving.affine
174+
self.affine_ = starting_affine
175+
self.reg_ = AffineMap(starting_affine,
176+
static.shape, static.affine,
177+
moving.shape, moving.affine)
178+
179+
def apply(self, moving):
180+
"""
181+
182+
"""
183+
data = moving.get_fdata()
184+
assert np.all(moving.affine, self.moving_affine_)
185+
return nb.Nifti1Image(np.array(self.reg_.transform(data)),
186+
self.static_affine_)

0 commit comments

Comments
 (0)