Skip to content

Commit 581e460

Browse files
authored
Merge pull request #453 from nipreps/fix/revise-bspline-fitting
FIX: Revision of the B-Spline fitting code
2 parents bb89bdc + 1883c15 commit 581e460

File tree

5 files changed

+65
-34
lines changed

5 files changed

+65
-34
lines changed

.circleci/config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,9 @@ jobs:
292292

293293
- restore_cache:
294294
keys:
295-
- workdir-v2-{{ .Branch }}-
296-
- workdir-v2-master-
297-
- workdir-v2-
295+
- workdir-v3-{{ .Branch }}-
296+
- workdir-v3-master-
297+
- workdir-v3-
298298
- run:
299299
name: Refreshing cached intermediate results
300300
working_directory: /tmp/src/sdcflows
@@ -343,7 +343,7 @@ jobs:
343343
--cov sdcflows --cov-report xml:/out/unittests.xml \
344344
sdcflows/
345345
- save_cache:
346-
key: workdir-v2-{{ .Branch }}-{{ .BuildNum }}
346+
key: workdir-v3-{{ .Branch }}-{{ .BuildNum }}
347347
paths:
348348
- /tmp/work
349349
- store_artifacts:

sdcflows/interfaces/bspline.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,17 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
5656
in_data = File(exists=True, mandatory=True, desc="path to a fieldmap")
5757
in_mask = File(exists=True, desc="path to a brain mask")
5858
bs_spacing = InputMultiObject(
59-
[DEFAULT_ZOOMS_MM],
59+
[DEFAULT_HF_ZOOMS_MM],
6060
traits.Tuple(traits.Float, traits.Float, traits.Float),
6161
usedefault=True,
6262
desc="spacing between B-Spline control points",
6363
)
6464
ridge_alpha = traits.Float(
65-
0.01, usedefault=True, desc="controls the regularization"
65+
1e-4, usedefault=True, desc="controls the regularization"
6666
)
6767
recenter = traits.Enum(
68-
"mode",
6968
"median",
69+
"mode",
7070
"mean",
7171
False,
7272
usedefault=True,
@@ -80,7 +80,7 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
8080
zooms_min = traits.Union(
8181
traits.Float,
8282
traits.Tuple(traits.Float, traits.Float, traits.Float),
83-
default_value=4.0,
83+
default_value=1.0,
8484
usedefault=True,
8585
desc="limit minimum image zooms, set 0.0 to use the original image",
8686
)
@@ -90,6 +90,7 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
9090

9191

9292
class _BSplineApproxOutputSpec(TraitedSpec):
93+
out_intercept = traits.Float
9394
out_field = File(exists=True)
9495
out_coeff = OutputMultiObject(File(exists=True))
9596
out_error = File(exists=True)
@@ -139,15 +140,15 @@ def _run_interface(self, runtime):
139140

140141
# Load in the fieldmap
141142
fmapnii = nb.load(self.inputs.in_data)
142-
fmapnii = nb.as_closest_canonical(fmapnii)
143+
# fmapnii = nb.as_closest_canonical(fmapnii)
143144
zooms = fmapnii.header.get_zooms()
144145

145146
# Get a mask (or define on the spot to cover the full extent)
146147
masknii = (
147148
nb.load(self.inputs.in_mask) if isdefined(self.inputs.in_mask) else None
148149
)
149-
if masknii is not None:
150-
masknii = nb.as_closest_canonical(masknii)
150+
# if masknii is not None:
151+
# masknii = nb.as_closest_canonical(masknii)
151152

152153
# Determine the shape of bspline coefficients
153154
# This should not change with resizing, so do it first
@@ -211,9 +212,7 @@ def _run_interface(self, runtime):
211212
)
212213

213214
# Fit the model
214-
model = lm.Ridge(
215-
alpha=self.inputs.ridge_alpha, fit_intercept=False, solver="lsqr"
216-
)
215+
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
217216
for attempt in range(3):
218217
model.fit(colmat, data.reshape(-1))
219218
extreme = np.abs(model.coef_).max()
@@ -228,6 +227,8 @@ def _run_interface(self, runtime):
228227
f"Extreme value {extreme:.2e} detected in spline coefficients."
229228
)
230229

230+
self._results["out_intercept"] = model.intercept_
231+
231232
# Store coefficients
232233
index = 0
233234
self._results["out_coeff"] = []
@@ -247,11 +248,11 @@ def _run_interface(self, runtime):
247248
# Interpolating in the original grid will require a new collocation matrix
248249
if need_resize:
249250
fmapnii = nb.load(self.inputs.in_data)
250-
fmapnii = nb.as_closest_canonical(fmapnii)
251+
# fmapnii = nb.as_closest_canonical(fmapnii)
251252
data = fmapnii.get_fdata(dtype="float32") - center
252253
if masknii is not None:
253254
masknii = nb.load(self.inputs.in_mask)
254-
masknii = nb.as_closest_canonical(masknii)
255+
# masknii = nb.as_closest_canonical(masknii)
255256
mask = np.asanyarray(masknii.dataobj) > 1e-4
256257
else:
257258
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
@@ -267,14 +268,20 @@ def _run_interface(self, runtime):
267268
# Store interpolated field
268269
hdr = fmapnii.header.copy()
269270
hdr.set_data_dtype("float32")
270-
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
271+
outnii = fmapnii.__class__(interp_data, fmapnii.affine, hdr)
272+
outnii.header["cal_max"] = np.abs(outnii.dataobj).max()
273+
outnii.header["cal_min"] = -outnii.header["cal_max"]
274+
outnii.to_filename(out_name)
271275
self._results["out_field"] = out_name
272276

273277
# Write out fitting-error map
274278
self._results["out_error"] = out_name.replace("_field.", "_error.")
275-
fmapnii.__class__(
276-
data * mask - interp_data, fmapnii.affine, fmapnii.header
277-
).to_filename(self._results["out_error"])
279+
errornii = fmapnii.__class__(
280+
(data - interp_data) * mask, fmapnii.affine, fmapnii.header
281+
)
282+
errornii.header["cal_min"] = 0
283+
errornii.header["cal_max"] = np.max(errornii.dataobj)
284+
errornii.to_filename(self._results["out_error"])
278285

279286
if not self.inputs.extrapolate:
280287
return runtime

sdcflows/interfaces/tests/test_bspline.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,41 +40,64 @@
4040
@pytest.mark.parametrize("testnum", range(100))
4141
def test_bsplines(tmp_path, testnum):
4242
"""Test idempotency of B-Splines interpolation + approximation."""
43-
targetshape = (10, 12, 9)
43+
targetshape = (50, 50, 30)
4444

4545
# Generate an oblique affine matrix for the target - it will be a common case.
4646
targetaff = nb.affines.from_matvec(
47-
nb.eulerangles.euler2mat(x=0.9, y=0.001, z=0.001) @ np.diag((2, 3, 4)),
47+
nb.eulerangles.euler2mat(x=-0.9, y=0.001, z=0.001) @ np.diag((2, 2, 2.4)),
4848
)
4949

5050
# Intendedly mis-centered (exercise we may not have volume-centered NIfTIs)
5151
targetaff[:3, 3] = nb.affines.apply_affine(
5252
targetaff, 0.5 * (np.array(targetshape) - 3)
5353
)
5454

55+
mask = np.zeros(targetshape)
56+
mask[10:-10, 10:-10, 6:-6] = 1
5557
# Generate some target grid
56-
targetnii = nb.Nifti1Image(np.ones(targetshape), targetaff, None)
57-
targetnii.to_filename(tmp_path / "target.nii.gz")
58+
targetnii = nb.Nifti1Image(mask, targetaff, None)
59+
targetnii.header.set_qform(targetaff, code=1)
60+
targetnii.header.set_sform(targetaff, code=1)
61+
targetnii.to_filename(tmp_path / "mask.nii.gz")
5862

5963
# Generate random coefficients
60-
gridnii = bspline_grid(targetnii, control_zooms_mm=(4, 6, 8))
61-
coeff = (rng.random(size=gridnii.shape) - 0.5) * 500
64+
gridnii = bspline_grid(targetnii, control_zooms_mm=(40, 40, 16))
65+
coeff = (rng.standard_normal(size=gridnii.shape)) * 100
6266
coeffnii = nb.Nifti1Image(coeff.astype("float32"), gridnii.affine, gridnii.header)
67+
coeffnii.header["cal_max"] = np.abs(coeff).max()
68+
coeffnii.header["cal_min"] = -coeffnii.header["cal_max"]
69+
coeffnii.header.set_qform(gridnii.affine, code=1)
70+
coeffnii.header.set_sform(gridnii.affine, code=1)
6371
coeffnii.to_filename(tmp_path / "coeffs.nii.gz")
6472

6573
os.chdir(tmp_path)
6674
# Check that we can interpolate the coefficients on a target
6775
test1 = ApplyCoeffsField(
68-
in_data=str(tmp_path / "target.nii.gz"),
76+
in_data=str(tmp_path / "mask.nii.gz"),
6977
in_coeff=str(tmp_path / "coeffs.nii.gz"),
7078
pe_dir="j-",
7179
ro_time=1.0,
7280
).run()
7381

82+
fieldnii = nb.load(test1.outputs.out_field)
83+
fielddata = fieldnii.get_fdata()
84+
fielddata -= np.median(fielddata)
85+
fielddata = 200 * fielddata / np.abs(fielddata).max()
86+
87+
fieldnii.header["cal_max"] = np.abs(fielddata).max()
88+
fieldnii.header["cal_min"] = -fieldnii.header["cal_max"]
89+
fieldnii.header.set_qform(targetaff, code=1)
90+
fieldnii.header.set_sform(targetaff, code=1)
91+
92+
nb.Nifti1Image(fielddata, targetaff, fieldnii.header).to_filename(
93+
tmp_path / "testfield.nii.gz",
94+
)
95+
7496
# Approximate the interpolated target
7597
test2 = BSplineApprox(
76-
in_data=test1.outputs.out_field,
77-
bs_spacing=[(4, 6, 8)],
98+
in_data=str(tmp_path / "testfield.nii.gz"),
99+
# in_mask=str(tmp_path / "mask.nii.gz"),
100+
bs_spacing=[(40, 40, 16)],
78101
zooms_min=0,
79102
recenter=False,
80103
ridge_alpha=1e-4,

sdcflows/utils/wrangler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def find_estimators(
339339
"Dataset includes `B0FieldIdentifier` metadata."
340340
"Any data missing this metadata will be ignored."
341341
)
342+
342343
for b0_id in b0_ids:
343344
# Found B0FieldIdentifier metadata entries
344345
b0_entities = base_entities.copy()

sdcflows/workflows/fit/fieldmap.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@ def init_fmap_wf(omp_nthreads=1, sloppy=False, debug=False, mode="phasediff", na
8686
"""
8787
from ...interfaces.bspline import (
8888
BSplineApprox,
89-
DEFAULT_LF_ZOOMS_MM,
9089
DEFAULT_HF_ZOOMS_MM,
91-
DEFAULT_ZOOMS_MM,
9290
)
9391
from ...interfaces.fmap import CheckRegister
9492

@@ -114,9 +112,11 @@ def _unzip(fmap_spec):
114112
magnitude_wf = init_magnitude_wf(omp_nthreads=omp_nthreads)
115113
bs_filter = pe.Node(BSplineApprox(), name="bs_filter")
116114
bs_filter.interface._always_run = debug
117-
bs_filter.inputs.bs_spacing = (
118-
[DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM] if not sloppy else [DEFAULT_ZOOMS_MM]
119-
)
115+
bs_filter.inputs.bs_spacing = [DEFAULT_HF_ZOOMS_MM]
116+
117+
if sloppy:
118+
bs_filter.inputs.zooms_min = 4.0
119+
120120
bs_filter.inputs.extrapolate = not debug
121121

122122
# fmt: off

0 commit comments

Comments
 (0)