Skip to content

Commit 69aa16f

Browse files
committed
Merge pull request #1 from matthew-brett/concat-image-outside-loop
MRG: suggested refactoring of concat checks
2 parents e9298cf + 188c7ea commit 69aa16f

File tree

2 files changed

+38
-49
lines changed

2 files changed

+38
-49
lines changed

nibabel/funcs.py

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,6 @@ def squeeze_image(img):
8888
img.extra)
8989

9090

91-
def _shape_equal_excluding(shape1, shape2, exclude_axes):
92-
""" Helper function to compare two array shapes, excluding any
93-
axis specified."""
94-
95-
if len(shape1) != len(shape2):
96-
return False
97-
98-
idx_mask = np.ones((len(shape1),), dtype=bool)
99-
idx_mask[exclude_axes] = False
100-
return np.array_equal(np.asarray(shape1)[idx_mask],
101-
np.asarray(shape2)[idx_mask])
102-
103-
10491
def concat_images(images, check_affines=True, axis=None):
10592
''' Concatenate images in list to single image, along specified dimension
10693
@@ -112,51 +99,53 @@ def concat_images(images, check_affines=True, axis=None):
11299
If True, then check that all the affines for `images` are nearly
113100
the same, raising a ``ValueError`` otherwise. Default is True
114101
axis : None or int, optional
115-
If None, concatenates on a new dimension. This rrequires all images
116-
to be the same shape).
117-
If not None, concatenates on the specified dimension. This requires
118-
all images to be the same shape, except on the specified dimension.
119-
For 4D images, axis must be between -2 and 3.
102+
If None, concatenates on a new dimension. This requires all images to
103+
be the same shape. If not None, concatenates on the specified
104+
dimension. This requires all images to be the same shape, except on
105+
the specified dimension.
120106
Returns
121107
-------
122108
concat_img : ``SpatialImage``
123109
New image resulting from concatenating `images` across last
124110
dimension
125111
'''
126-
112+
images = [load(img) if not hasattr(img, 'get_data')
113+
else img for img in images]
127114
n_imgs = len(images)
128115
if n_imgs == 0:
129116
raise ValueError("Cannot concatenate an empty list of images.")
130-
117+
img0 = images[0]
118+
affine = img0.affine
119+
header = img0.header
120+
klass = img0.__class__
121+
shape0 = img0.shape
122+
n_dim = len(shape0)
123+
if axis is None:
124+
# collect images in output array for efficiency
125+
out_shape = (n_imgs, ) + shape0
126+
out_data = np.empty(out_shape)
127+
else:
128+
# collect images in list for use with np.concatenate
129+
out_data = [None] * n_imgs
130+
# Get part of shape we need to check inside loop
131+
idx_mask = np.ones((n_dim,), dtype=bool)
132+
if axis is not None:
133+
idx_mask[axis] = False
134+
masked_shape = np.array(shape0)[idx_mask]
131135
for i, img in enumerate(images):
132-
if not hasattr(img, 'get_data'):
133-
img = load(img)
134-
135-
if i == 0: # first image, initialize data from loaded image
136-
affine = img.affine
137-
header = img.header
138-
shape = img.shape
139-
klass = img.__class__
140-
141-
if axis is None: # collect images in output array for efficiency
142-
out_shape = (n_imgs, ) + shape
143-
out_data = np.empty(out_shape)
144-
else: # collect images in list for use with np.concatenate
145-
out_data = [None] * n_imgs
146-
147-
elif check_affines and not np.all(img.affine == affine):
148-
raise ValueError('Affines do not match')
149-
150-
elif ((axis is None and not np.array_equal(shape, img.shape)) or
151-
(axis is not None and not _shape_equal_excluding(shape, img.shape,
152-
exclude_axes=[axis]))):
153-
# shape mismatch; numpy broadcast / concatenate can hide these.
154-
raise ValueError("Image #%d (shape=%s) does not match the first "
155-
"image shape (%s)." % (i, shape, img.shape))
156-
157-
out_data[i] = img.get_data()
158-
159-
del img
136+
if len(img.shape) != n_dim:
137+
raise ValueError(
138+
'Image {0} has {1} dimensions, image 0 has {2}'.format(
139+
i, len(img.shape), n_dim))
140+
if not np.all(np.array(img.shape)[idx_mask] == masked_shape):
141+
raise ValueError('shape {0} for image {1} not compatible with '
142+
'first image shape {2} with axis == {0}'.format(
143+
img.shape, i, shape0, axis))
144+
if check_affines and not np.all(img.affine == affine):
145+
raise ValueError('Affine for image {0} does not match affine '
146+
'for first image'.format(i))
147+
# Do not fill cache in image if it is empty
148+
out_data[i] = img.get_data(caching='unchanged')
160149

161150
if axis is None:
162151
out_data = np.rollaxis(out_data, 0, out_data.ndim)

nibabel/tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ def test_concat():
7979
# but our efficient logic (where all images are
8080
# 3D and the same size) fails, so we also
8181
# have to expect errors for those.
82-
expect_error = data0.ndim != data1.ndim
8382
if axis is None: # 3D from here and below
8483
all_data = np.concatenate([data0[..., np.newaxis],
8584
data1[..., np.newaxis]],
8685
**np_concat_kwargs)
8786
else: # both 3D, appending on final axis
8887
all_data = np.concatenate([data0, data1],
8988
**np_concat_kwargs)
89+
expect_error = False
9090
except ValueError:
9191
# Shapes are not combinable
9292
expect_error = True

0 commit comments

Comments
 (0)