Skip to content

Commit b44d2bb

Browse files
committed
fix: do not fit intercept, improve tests
1 parent bbd9261 commit b44d2bb

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

sdcflows/interfaces/bspline.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,15 @@ def _run_interface(self, runtime):
140140

141141
# Load in the fieldmap
142142
fmapnii = nb.load(self.inputs.in_data)
143-
fmapnii = nb.as_closest_canonical(fmapnii)
143+
# fmapnii = nb.as_closest_canonical(fmapnii)
144144
zooms = fmapnii.header.get_zooms()
145145

146146
# Get a mask (or define on the spot to cover the full extent)
147147
masknii = (
148148
nb.load(self.inputs.in_mask) if isdefined(self.inputs.in_mask) else None
149149
)
150-
if masknii is not None:
151-
masknii = nb.as_closest_canonical(masknii)
150+
# if masknii is not None:
151+
# masknii = nb.as_closest_canonical(masknii)
152152

153153
# Determine the shape of bspline coefficients
154154
# This should not change with resizing, so do it first
@@ -212,7 +212,7 @@ def _run_interface(self, runtime):
212212
)
213213

214214
# Fit the model
215-
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=True)
215+
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
216216
for attempt in range(3):
217217
model.fit(colmat, data.reshape(-1))
218218
extreme = np.abs(model.coef_).max()
@@ -248,11 +248,11 @@ def _run_interface(self, runtime):
248248
# Interpolating in the original grid will require a new collocation matrix
249249
if need_resize:
250250
fmapnii = nb.load(self.inputs.in_data)
251-
fmapnii = nb.as_closest_canonical(fmapnii)
251+
# fmapnii = nb.as_closest_canonical(fmapnii)
252252
data = fmapnii.get_fdata(dtype="float32") - center
253253
if masknii is not None:
254254
masknii = nb.load(self.inputs.in_mask)
255-
masknii = nb.as_closest_canonical(masknii)
255+
# masknii = nb.as_closest_canonical(masknii)
256256
mask = np.asanyarray(masknii.dataobj) > 1e-4
257257
else:
258258
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
@@ -268,14 +268,20 @@ def _run_interface(self, runtime):
268268
# Store interpolated field
269269
hdr = fmapnii.header.copy()
270270
hdr.set_data_dtype("float32")
271-
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
271+
outnii = fmapnii.__class__(interp_data, fmapnii.affine, hdr)
272+
outnii.header["cal_max"] = np.abs(outnii.dataobj).max()
273+
outnii.header["cal_min"] = -outnii.header["cal_max"]
274+
outnii.to_filename(out_name)
272275
self._results["out_field"] = out_name
273276

274277
# Write out fitting-error map
275278
self._results["out_error"] = out_name.replace("_field.", "_error.")
276-
fmapnii.__class__(
277-
data * mask - interp_data + model.intercept_, fmapnii.affine, fmapnii.header
278-
).to_filename(self._results["out_error"])
279+
errornii = fmapnii.__class__(
280+
(data - interp_data) * mask, fmapnii.affine, fmapnii.header
281+
)
282+
errornii.header["cal_min"] = 0
283+
errornii.header["cal_max"] = np.max(errornii.dataobj)
284+
errornii.to_filename(self._results["out_error"])
279285

280286
if not self.inputs.extrapolate:
281287
return runtime

sdcflows/interfaces/tests/test_bspline.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,41 +40,64 @@
4040
@pytest.mark.parametrize("testnum", range(100))
4141
def test_bsplines(tmp_path, testnum):
4242
"""Test idempotency of B-Splines interpolation + approximation."""
43-
targetshape = (10, 12, 9)
43+
targetshape = (50, 50, 30)
4444

4545
# Generate an oblique affine matrix for the target - it will be a common case.
4646
targetaff = nb.affines.from_matvec(
47-
nb.eulerangles.euler2mat(x=0.9, y=0.001, z=0.001) @ np.diag((2, 3, 4)),
47+
nb.eulerangles.euler2mat(x=-0.9, y=0.001, z=0.001) @ np.diag((2, 2, 2.4)),
4848
)
4949

5050
# Intendedly mis-centered (exercise we may not have volume-centered NIfTIs)
5151
targetaff[:3, 3] = nb.affines.apply_affine(
5252
targetaff, 0.5 * (np.array(targetshape) - 3)
5353
)
5454

55+
mask = np.zeros(targetshape)
56+
mask[10:-10, 10:-10, 6:-6] = 1
5557
# Generate some target grid
56-
targetnii = nb.Nifti1Image(np.ones(targetshape), targetaff, None)
57-
targetnii.to_filename(tmp_path / "target.nii.gz")
58+
targetnii = nb.Nifti1Image(mask, targetaff, None)
59+
targetnii.header.set_qform(targetaff, code=1)
60+
targetnii.header.set_sform(targetaff, code=1)
61+
targetnii.to_filename(tmp_path / "mask.nii.gz")
5862

5963
# Generate random coefficients
60-
gridnii = bspline_grid(targetnii, control_zooms_mm=(4, 6, 8))
61-
coeff = (rng.random(size=gridnii.shape) - 0.5) * 500
64+
gridnii = bspline_grid(targetnii, control_zooms_mm=(40, 40, 16))
65+
coeff = (rng.standard_normal(size=gridnii.shape)) * 100
6266
coeffnii = nb.Nifti1Image(coeff.astype("float32"), gridnii.affine, gridnii.header)
67+
coeffnii.header["cal_max"] = np.abs(coeff).max()
68+
coeffnii.header["cal_min"] = -coeffnii.header["cal_max"]
69+
coeffnii.header.set_qform(gridnii.affine, code=1)
70+
coeffnii.header.set_sform(gridnii.affine, code=1)
6371
coeffnii.to_filename(tmp_path / "coeffs.nii.gz")
6472

6573
os.chdir(tmp_path)
6674
# Check that we can interpolate the coefficients on a target
6775
test1 = ApplyCoeffsField(
68-
in_data=str(tmp_path / "target.nii.gz"),
76+
in_data=str(tmp_path / "mask.nii.gz"),
6977
in_coeff=str(tmp_path / "coeffs.nii.gz"),
7078
pe_dir="j-",
7179
ro_time=1.0,
7280
).run()
7381

82+
fieldnii = nb.load(test1.outputs.out_field)
83+
fielddata = fieldnii.get_fdata()
84+
fielddata -= np.median(fielddata)
85+
fielddata = 200 * fielddata / np.abs(fielddata).max()
86+
87+
fieldnii.header["cal_max"] = np.abs(fielddata).max()
88+
fieldnii.header["cal_min"] = -fieldnii.header["cal_max"]
89+
fieldnii.header.set_qform(targetaff, code=1)
90+
fieldnii.header.set_sform(targetaff, code=1)
91+
92+
nb.Nifti1Image(fielddata, targetaff, fieldnii.header).to_filename(
93+
tmp_path / "testfield.nii.gz",
94+
)
95+
7496
# Approximate the interpolated target
7597
test2 = BSplineApprox(
76-
in_data=test1.outputs.out_field,
77-
bs_spacing=[(4, 6, 8)],
98+
in_data=str(tmp_path / "testfield.nii.gz"),
99+
# in_mask=str(tmp_path / "mask.nii.gz"),
100+
bs_spacing=[(40, 40, 16)],
78101
zooms_min=0,
79102
recenter=False,
80103
ridge_alpha=1e-4,

0 commit comments

Comments
 (0)