Skip to content

Commit 0837e91

Browse files
committed
sty: pacify flake8
1 parent f59720d commit 0837e91

File tree

6 files changed

+85
-73
lines changed

6 files changed

+85
-73
lines changed

nitransforms/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,10 @@ def __ne__(self, other):
177177
class TransformBase:
178178
"""Abstract image class to represent transforms."""
179179

180-
__slots__ = ("_reference", "_ndim",)
180+
__slots__ = (
181+
"_reference",
182+
"_ndim",
183+
)
181184

182185
def __init__(self, reference=None):
183186
"""Instantiate a transform."""
@@ -283,6 +286,7 @@ def _as_homogeneous(xyz, dtype="float32", dim=3):
283286

284287
return np.hstack((xyz, np.ones((xyz.shape[0], 1), dtype=dtype)))
285288

289+
286290
def _apply_affine(x, affine, dim):
287291
"""Get the image array's indexes corresponding to coordinates."""
288292
return np.tensordot(

nitransforms/linear.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,12 @@
1010
import warnings
1111
import numpy as np
1212
from pathlib import Path
13-
from scipy import ndimage as ndi
1413

15-
from nibabel.loadsave import load as _nbload
1614
from nibabel.affines import from_matvec
17-
from nibabel.arrayproxy import get_obj_dtype
1815

1916
from nitransforms.base import (
2017
ImageGrid,
2118
TransformBase,
22-
SpatialReference,
2319
_as_homogeneous,
2420
EQUALITY_TOL,
2521
)
@@ -112,7 +108,7 @@ def __invert__(self):
112108
113109
"""
114110
return self.__class__(self._inverse)
115-
111+
116112
def __len__(self):
117113
"""Enable using len()."""
118114
return 1 if self._matrix.ndim == 2 else len(self._matrix)

nitransforms/nonlinear.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414
from nitransforms import io
1515
from nitransforms.io.base import _ensure_image
1616
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
17-
from nitransforms.resampling import apply
1817
from nitransforms.base import (
1918
TransformBase,
2019
TransformError,
2120
ImageGrid,
22-
SpatialReference,
2321
_as_homogeneous,
2422
)
2523
from scipy.ndimage import map_coordinates
@@ -77,14 +75,12 @@ def __init__(self, field=None, is_deltas=True, reference=None):
7775
is_deltas = True
7876

7977
try:
80-
self.reference = ImageGrid(
81-
reference if reference is not None else field
82-
)
78+
self.reference = ImageGrid(reference if reference is not None else field)
8379
except AttributeError:
8480
raise TransformError(
8581
"Field must be a spatial image if reference is not provided"
86-
if reference is None else
87-
"Reference is not a spatial image"
82+
if reference is None
83+
else "Reference is not a spatial image"
8884
)
8985

9086
if self._field.shape[-1] != self.ndim:
@@ -175,16 +171,19 @@ def map(self, x, inverse=False):
175171
indexes = tuple(tuple(i) for i in indexes)
176172
return self._field[indexes]
177173

178-
return np.vstack(tuple(
179-
map_coordinates(
180-
self._field[..., i],
181-
ijk.T,
182-
order=3,
183-
mode="constant",
184-
cval=0,
185-
prefilter=True,
186-
) for i in range(self.reference.ndim)
187-
)).T
174+
return np.vstack(
175+
tuple(
176+
map_coordinates(
177+
self._field[..., i],
178+
ijk.T,
179+
order=3,
180+
mode="constant",
181+
cval=0,
182+
prefilter=True,
183+
)
184+
for i in range(self.reference.ndim)
185+
)
186+
).T
188187

189188
def __matmul__(self, b):
190189
"""
@@ -206,9 +205,9 @@ def __matmul__(self, b):
206205
True
207206
208207
"""
209-
retval = b.map(
210-
self._field.reshape((-1, self._field.shape[-1]))
211-
).reshape(self._field.shape)
208+
retval = b.map(self._field.reshape((-1, self._field.shape[-1]))).reshape(
209+
self._field.shape
210+
)
212211
return DenseFieldTransform(retval, is_deltas=False, reference=self.reference)
213212

214213
def __eq__(self, other):
@@ -247,12 +246,12 @@ def from_filename(cls, filename, fmt="X5"):
247246
class BSplineFieldTransform(TransformBase):
248247
"""Represent a nonlinear transform parameterized by BSpline basis."""
249248

250-
__slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving']
249+
__slots__ = ["_coeffs", "_knots", "_weights", "_order", "_moving"]
251250

252251
@property
253252
def ndim(self):
254253
"""Access the dimensions of this BSpline."""
255-
#return ndim = self._coeffs.shape[-1]
254+
# return ndim = self._coeffs.shape[-1]
256255
return self._coeffs.ndim - 1
257256

258257
def __init__(self, coefficients, reference=None, order=3):
@@ -270,8 +269,9 @@ def __init__(self, coefficients, reference=None, order=3):
270269

271270
if coefficients.shape[-1] != self.reference.ndim:
272271
raise TransformError(
273-
'Number of components of the coefficients does '
274-
'not match the number of dimensions')
272+
"Number of components of the coefficients does "
273+
"not match the number of dimensions"
274+
)
275275

276276
@property
277277
def ndim(self):
@@ -281,8 +281,7 @@ def ndim(self):
281281
def to_field(self, reference=None, dtype="float32"):
282282
"""Generate a displacements deformation field from this B-Spline field."""
283283
_ref = (
284-
self.reference if reference is None else
285-
ImageGrid(_ensure_image(reference))
284+
self.reference if reference is None else ImageGrid(_ensure_image(reference))
286285
)
287286
if _ref is None:
288287
raise TransformError("A reference must be defined")
@@ -350,9 +349,9 @@ def _map_xyz(x, reference, knots, coeffs):
350349
# Probably this will change if the order of the B-Spline is different
351350
w_start, w_end = np.ceil(ijk - 2).astype(int), np.floor(ijk + 2).astype(int)
352351
# Generate a grid of indexes corresponding to the window
353-
nonzero_knots = tuple([
354-
np.arange(start, end + 1) for start, end in zip(w_start, w_end)
355-
])
352+
nonzero_knots = tuple(
353+
[np.arange(start, end + 1) for start, end in zip(w_start, w_end)]
354+
)
356355
nonzero_knots = tuple(np.meshgrid(*nonzero_knots, indexing="ij"))
357356
window = np.array(nonzero_knots).reshape((ndim, -1))
358357

nitransforms/resampling.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def apply(
7777
reference = _nbload(str(reference))
7878

7979
_ref = (
80-
transform.reference if reference is None else SpatialReference.factory(reference)
80+
transform.reference
81+
if reference is None
82+
else SpatialReference.factory(reference)
8183
)
8284

8385
if _ref is None:
@@ -89,20 +91,25 @@ def apply(
8991
data = np.asanyarray(spatialimage.dataobj)
9092

9193
if data.ndim == 4 and data.shape[-1] != len(transform):
92-
raise ValueError("The fourth dimension of the data does not match the tranform's shape.")
94+
raise ValueError(
95+
"The fourth dimension of the data does not match the tranform's shape."
96+
)
9397

9498
if data.ndim < transform.ndim:
9599
data = data[..., np.newaxis]
96-
97-
if hasattr(transform, 'to_field') and callable(transform.to_field):
100+
101+
if hasattr(transform, "to_field") and callable(transform.to_field):
98102
targets = ImageGrid(spatialimage).index(
99-
_as_homogeneous(transform.to_field(reference=reference).map(_ref.ndcoords.T), dim=_ref.ndim)
103+
_as_homogeneous(
104+
transform.to_field(reference=reference).map(_ref.ndcoords.T),
105+
dim=_ref.ndim,
106+
)
100107
)
101108
else:
102109
targets = ImageGrid(spatialimage).index( # data should be an image
103110
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
104111
)
105-
112+
106113
if transform.ndim == 4:
107114
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
108115

@@ -115,14 +122,14 @@ def apply(
115122
cval=cval,
116123
prefilter=prefilter,
117124
)
118-
125+
119126
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
120127
hdr = None
121128
if _ref.header is not None:
122129
hdr = _ref.header.copy()
123130
hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype())
124131
moved = spatialimage.__class__(
125-
resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1, )),
132+
resampled.reshape(_ref.shape if data.ndim < 4 else _ref.shape + (-1,)),
126133
_ref.affine,
127134
hdr,
128135
)

nitransforms/tests/test_base.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
import pytest
55
import h5py
66

7-
from ..base import SpatialReference, SampledSpatialData, ImageGrid, TransformBase, _as_homogeneous
7+
from ..base import (
8+
SpatialReference,
9+
SampledSpatialData,
10+
ImageGrid,
11+
TransformBase,
12+
)
813
from .. import linear as nitl
914
from ..resampling import apply
1015

@@ -104,15 +109,11 @@ def _to_hdf5(klass, x5_root):
104109
xfm = nitl.Affine()
105110
xfm.reference = fname
106111
moved = apply(xfm, fname, order=0)
107-
assert np.all(
108-
imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())
109-
)
112+
assert np.all(imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype()))
110113

111114
# Test ndim returned by affine
112115
assert nitl.Affine().ndim == 3
113-
assert nitl.LinearTransformsMapping(
114-
[nitl.Affine(), nitl.Affine()]
115-
).ndim == 4
116+
assert nitl.LinearTransformsMapping([nitl.Affine(), nitl.Affine()]).ndim == 4
116117

117118
# Test applying to Gifti
118119
gii = nb.gifti.GiftiImage(
@@ -127,7 +128,7 @@ def _to_hdf5(klass, x5_root):
127128
assert np.allclose(giimoved.reshape(xfm.reference.shape), moved.get_fdata())
128129

129130
# Test to_filename
130-
xfm.to_filename("data.xfm", fmt='itk')
131+
xfm.to_filename("data.xfm", fmt="itk")
131132

132133

133134
def test_SampledSpatialData(testdata_path):

nitransforms/tests/test_nonlinear.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
3dNwarpApply -nwarp {transform} -source {moving} \
3030
-master {reference} -interp NN -prefix {output} {extra}\
3131
""".format,
32-
'fsl': """\
32+
"fsl": """\
3333
applywarp -i {moving} -r {reference} -o {output} {extra}\
3434
-w {transform} --interp=nn""".format,
3535
}
@@ -39,7 +39,9 @@
3939
def test_itk_disp_load(size):
4040
"""Checks field sizes."""
4141
with pytest.raises(TransformFileError):
42-
ITKDisplacementsField.from_image(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
42+
ITKDisplacementsField.from_image(
43+
nb.Nifti1Image(np.zeros(size), np.eye(4), None)
44+
)
4345

4446

4547
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3), (20, 20, 20, 1, 4)])
@@ -98,15 +100,16 @@ def test_bsplines_references(testdata_path):
98100

99101
with pytest.raises(TransformError):
100102
apply(
101-
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
103+
BSplineFieldTransform(
104+
testdata_path / "someones_bspline_coefficients.nii.gz"
105+
),
102106
testdata_path / "someones_anatomy.nii.gz",
103107
)
104108

105109
apply(
106-
BSplineFieldTransform(
107-
testdata_path / "someones_bspline_coefficients.nii.gz"),
110+
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
108111
testdata_path / "someones_anatomy.nii.gz",
109-
reference=testdata_path / "someones_anatomy.nii.gz"
112+
reference=testdata_path / "someones_anatomy.nii.gz",
110113
)
111114

112115

@@ -170,7 +173,7 @@ def test_displacements_field1(
170173
nt_moved_mask.set_data_dtype(msk.get_data_dtype())
171174
diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj)
172175

173-
assert np.sqrt((diff ** 2).mean()) < RMSE_TOL
176+
assert np.sqrt((diff**2).mean()) < RMSE_TOL
174177
brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)
175178

176179
# Then apply the transform and cross-check with software
@@ -179,7 +182,7 @@ def test_displacements_field1(
179182
reference=tmp_path / "reference.nii.gz",
180183
moving=tmp_path / "reference.nii.gz",
181184
output=tmp_path / "resampled.nii.gz",
182-
extra="--output-data-type uchar" if sw_tool == "itk" else ""
185+
extra="--output-data-type uchar" if sw_tool == "itk" else "",
183186
)
184187

185188
exit_code = check_call([cmd], shell=True)
@@ -190,10 +193,9 @@ def test_displacements_field1(
190193
nt_moved.set_data_dtype(nii.get_data_dtype())
191194
nt_moved.to_filename("nt_resampled.nii.gz")
192195
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
193-
diff = (
194-
np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype())
195-
- np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
196-
)
196+
diff = np.asanyarray(
197+
sw_moved.dataobj, dtype=sw_moved.get_data_dtype()
198+
) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
197199
# A certain tolerance is necessary because of resampling at borders
198200
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL
199201

@@ -230,12 +232,11 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):
230232
nt_moved = xfm.apply(img_fname, order=0)
231233
nt_moved.to_filename("nt_resampled.nii.gz")
232234
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
233-
diff = (
234-
np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype())
235-
- np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
236-
)
235+
diff = np.asanyarray(
236+
sw_moved.dataobj, dtype=sw_moved.get_data_dtype()
237+
) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
237238
# A certain tolerance is necessary because of resampling at borders
238-
assert np.sqrt((diff ** 2).mean()) < RMSE_TOL
239+
assert np.sqrt((diff**2).mean()) < RMSE_TOL
239240

240241

241242
def test_bspline(tmp_path, testdata_path):
@@ -249,12 +250,16 @@ def test_bspline(tmp_path, testdata_path):
249250
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
250251
dispxfm = DenseFieldTransform(disp_name)
251252

252-
out_disp = apply(dispxfm,img_name)
253-
out_bspl = apply(bsplxfm,img_name)
253+
out_disp = apply(dispxfm, img_name)
254+
out_bspl = apply(bsplxfm, img_name)
254255

255256
out_disp.to_filename("resampled_field.nii.gz")
256257
out_bspl.to_filename("resampled_bsplines.nii.gz")
257258

258-
assert np.sqrt(
259-
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32")) ** 2
260-
).mean() < 0.2
259+
assert (
260+
np.sqrt(
261+
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
262+
** 2
263+
).mean()
264+
< 0.2
265+
)

0 commit comments

Comments
 (0)