Skip to content

Commit 3de72de

Browse files
authored
Merge pull request #1183 from MouseLand/3d_shape_fix
Refactor image conversion functions for improved clarity and error handling
2 parents fb5a6c0 + 9cb3c29 commit 3de72de

File tree

4 files changed

+101
-63
lines changed

4 files changed

+101
-63
lines changed

cellpose/io.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -230,33 +230,7 @@ def imread_2D(img_file):
230230
img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
231231
"""
232232
img = imread(img_file)
233-
234-
if (img.ndim == 1) or (img.ndim == 4):
235-
raise ValueError("img_file should have 2 or 3 dimensions, shape: %s" % img.shape)
236-
237-
# if image has no channel dimension, add one and return the image
238-
if img.ndim == 2:
239-
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
240-
img_out[:, :, 0] = img
241-
return img_out
242-
243-
# Otherwise, image will have a channel dimension, assume it's either first or last
244-
# force it to be last (XYC):
245-
if img.shape[0] < img.shape[-1]:
246-
img = np.moveaxis(img, 0, -1)
247-
248-
nchan = img.shape[-1]
249-
250-
if nchan == 3:
251-
# already has 3 channels
252-
return img
253-
254-
# ensure there are 3 channels
255-
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
256-
copy_chan = min(3, nchan)
257-
img_out[:, :, :copy_chan] = img[:, :, :copy_chan]
258-
259-
return img_out
233+
return transforms.convert_image(img, do_3D=False)
260234

261235

262236
def imread_3D(img_file):
    """Read a volumetric image file and convert it with `transforms.convert_image`.

    The axis layout is guessed from the shape of the loaded array:

    - 3D arrays are treated as grayscale volumes with no channel axis; the
      shortest axis is taken to be the z axis.
    - Otherwise the shortest axis is taken to be the channel axis and the
      shortest of the remaining axes is taken to be the z axis.

    Both guesses are indices into the array as loaded, which is the indexing
    `transforms.convert_image` uses for `channel_axis` and `z_axis`.

    Args:
        img_file: Path of the image file, passed unchanged to `imread`.

    Returns:
        The result of `transforms.convert_image(..., do_3D=True)`, or None if
        the conversion raised; the error and the guessed axes are logged.
    """
    img = imread(img_file)

    axis_lengths = list(img.shape)

    if img.ndim == 3:
        # grayscale images: no channel axis, so the shortest axis is z
        channel_axis = None
        z_axis = int(np.argmin(axis_lengths))
    else:
        # guess at channel axis: the shortest axis overall
        channel_axis = int(np.argmin(axis_lengths))

        # guess at z axis: the shortest remaining axis. The candidates keep
        # their positions in the full shape; taking argmin of the shape with
        # the channel axis deleted would be off by one whenever the channel
        # axis precedes the z axis.
        candidates = [ax for ax in range(img.ndim) if ax != channel_axis]
        z_axis = min(candidates, key=lambda ax: axis_lengths[ax])

    try:
        return transforms.convert_image(img, channel_axis=channel_axis, z_axis=z_axis, do_3D=True)
    except Exception as e:
        io_logger.critical("ERROR: could not read file, %s" % e)
        io_logger.critical("ERROR: Guessed z_axis: %s, channel_axis: %s" % (z_axis, channel_axis))
        return None
288267

289268
def remove_model(filename, delete=False):
290269
""" remove model from .cellpose custom model list """

cellpose/transforms.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,12 @@ def update_axis(m_axis, to_squeeze, ndim):
447447
return m_axis
448448

449449

450-
def convert_image_3d(x, channel_axis=None, z_axis=None):
450+
def _convert_image_3d(x, channel_axis=None, z_axis=None):
451451
"""
452452
Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C).
453+
454+
Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis.
455+
Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified.
453456
454457
Args:
455458
x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D.
@@ -474,6 +477,12 @@ def convert_image_3d(x, channel_axis=None, z_axis=None):
474477
channels to ensure the output has exactly 3 channels.
475478
"""
476479

480+
if x.ndim < 3:
481+
raise ValueError(f"Input image must have at least 3 dimensions, input shape: {x.shape}, ndim={x.ndim}")
482+
483+
if z_axis is not None and z_axis < 0:
484+
z_axis += x.ndim
485+
477486
# if image is ndim==3, assume it is greyscale 3D and use provided z_axis
478487
if x.ndim == 3 and z_axis is not None:
479488
# add in channel axis
@@ -484,7 +493,11 @@ def convert_image_3d(x, channel_axis=None, z_axis=None):
484493

485494

486495
if channel_axis is None or z_axis is None:
487-
raise ValueError("both channel_axis and z_axis must be specified when segmenting 3D images of ndim=4")
496+
raise ValueError("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters.")
497+
if channel_axis is not None and channel_axis < 0:
498+
channel_axis += x.ndim
499+
if channel_axis is None or channel_axis >= x.ndim:
500+
raise IndexError(f"channel_axis {channel_axis} is out of bounds for input array with {x.ndim} dimensions")
488501
assert x.ndim == 4, f"input image must have ndim == 4, ndim={x.ndim}"
489502

490503
x_dim_shapes = list(x.shape)
@@ -519,23 +532,26 @@ def convert_image_3d(x, channel_axis=None, z_axis=None):
519532
x = x[..., :x_chans_to_copy]
520533
else:
521534
# fewer than 3 channels: zero-pad the channel axis up to 3
522-
x_out = np.zeros((num_z_layers, x_dim_shapes[0], x_dim_shapes[1], 3), dtype=x.dtype)
523-
x_out[..., :x_chans_to_copy] = x[...]
524-
x = x_out
525-
del x_out
535+
pad_width = [(0, 0), (0, 0), (0, 0), (0, 3 - x_chans_to_copy)]
536+
x = np.pad(x, pad_width, mode='constant', constant_values=0)
526537

527538
return x
528539

529540

530541
def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
531-
"""Converts the image to have the z-axis first, channels last.
542+
"""Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
543+
If more than 3 channels are provided, only the first 3 channels will be used.
544+
545+
Accepts:
546+
- 2D images with no channel dimension: `z_axis` and `channel_axis` must be `None`
547+
- 2D images with channel dimension: `channel_axis` will be guessed between first or last axis, can also specify `channel_axis`. `z_axis` must be `None`
548+
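- 3D images with or without channels: `do_3D` must be `True` and `z_axis` must be specified; arrays of ndim=3 are treated as grayscale, arrays of ndim=4 must also specify `channel_axis`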
532549
533550
Args:
534551
x (numpy.ndarray or torch.Tensor): The input image.
535552
channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically.
536553
z_axis (int or None): The axis of the z-dimension in the input image. Must be None when do_3D is False and must be specified when do_3D is True.
537554
do_3D (bool): Whether to process the image in 3D mode. Defaults to False.
538-
nchan (int): The number of channels to keep if the input image has more than nchan channels.
539555
540556
Returns:
541557
numpy.ndarray: The converted image.
@@ -551,18 +567,23 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
551567
transforms_logger.warning("torch array used as input, converting to numpy")
552568
x = x.cpu().numpy()
553569

570+
# should be 2D
571+
if z_axis is not None and not do_3D:
572+
raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")
573+
574+
# 4D input is only valid in 3D mode
575+
if ndim == 4 and not do_3D:
576+
raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
577+
554578
# make sure that channel_axis and z_axis are specified if 3D
555579
if do_3D:
556-
return convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)
580+
return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)
557581

558-
if ndim == 4:
559-
raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images.")
560-
561582
######################## 2D reshaping ########################
562583
# if user specifies channel axis, return early
563584
if channel_axis is not None:
564585
if ndim == 2:
565-
raise ValueError("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images.")
586+
raise ValueError("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images of ndim=2.")
566587

567588
# Put channel axis last:
568589
# Find the indices of the dims that need to be put in dim 0 and 1
@@ -613,8 +634,9 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
613634
del x_out
614635
else:
615636
# something is wrong: yell
616-
transforms_logger.critical(f"ERROR: Unexpected image shape: {str(x.shape)}")
617-
raise ValueError(f"ERROR: Unexpected image shape: {str(x.shape)}")
637+
expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"
638+
transforms_logger.critical(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")
639+
raise ValueError(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")
618640

619641
return x
620642

tests/test_output.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,17 @@ def test_cyto2_to_seg(data_dir, image_names, cellposemodel_fixture_24layer):
9494
clear_output(data_dir, image_names)
9595

9696

97-
def test_class_3D_one_img(data_dir, image_names_3d, cellposemodel_fixture_2layer):
    """Run 3D evaluation on the first 3D test image and check output shapes.

    Only the shapes of the predicted masks and flows are compared with the
    input volume; mask contents are not compared.
    """
    clear_output(data_dir, image_names_3d)

    # Bind the name explicitly: the assertion messages below interpolate it.
    image_name = image_names_3d[0]
    img_file = data_dir / '3D' / image_name
    img = io.imread_3D(img_file)
    # imread_3D logs and returns None when the conversion fails
    assert img is not None, f'could not read {image_name}'

    masks_pred, flows_pred, _ = cellposemodel_fixture_2layer.eval(img, do_3D=True, channel_axis=-1, z_axis=0)

    assert img.shape[:-1] == masks_pred.shape, f'mask incorrect shape for {image_name}, {masks_pred.shape=}'
    assert img.shape[:-1] == flows_pred[1].shape[1:], f'flows incorrect shape for {image_name}, {flows_pred[1].shape=}'

    # just compare shapes for now
    # compare_masks_cp4(data_dir, image_names_3d, "3D")
    clear_output(data_dir, image_names_3d)
112110

tests/test_transforms.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from cellpose import transforms
45
from cellpose.io import imread
56
from cellpose.transforms import normalize_img, random_rotate_and_resize, resize_image
67

@@ -90,3 +91,41 @@ def test_resize(img_2d):
9091
img32 = resize_image(img_2d.astype("uint32"), Lx=Lx, Ly=Ly)
9192
assert img32.shape == (Ly, Lx, 3)
9293
assert img32.dtype == np.uint32
94+
95+
96+
# Each case: (input_shape, channel_axis, z_axis, do_3D, expected_shape, raises_error)

# Inputs that convert_image must accept.
_CONVERT_IMAGE_VALID_CASES = [
    # 2D:
    ((100, 120), None, None, False, (100, 120, 3), False),  # grayscale, no channel axis
    ((100, 120, 3), None, None, False, (100, 120, 3), False),  # RGB, channels last
    ((3, 100, 120), 0, None, False, (100, 120, 3), False),  # RGB, channels first, channel_axis given
    ((3, 100, 120), None, None, False, (100, 120, 3), False),  # RGB, channels first, channel_axis guessed

    # 3D:
    ((100, 120, 5), None, -1, True, (5, 100, 120, 3), False),  # grayscale volume, z last
    ((5, 100, 120), None, 0, True, (5, 100, 120, 3), False),  # grayscale volume, z first
    ((100, 5, 120, 5), 1, 3, True, (5, 100, 120, 3), False),  # 5-channel volume, reduced to 3 channels
    ((10, 100, 120, 3), -1, 0, True, (10, 100, 120, 3), False),  # 3-channel volume, z first, channels last
]

# Inputs that convert_image must reject with ValueError.
_CONVERT_IMAGE_INVALID_CASES = [
    # 2D:
    ((100, 120), None, 0, False, (100, 120, 3), True),  # z_axis given while do_3D is False
    ((100, 120, 3), None, None, True, (100, 120, 3), True),  # do_3D without z_axis
    ((3, 100, 120), -1, 2, False, (100, 120, 3), True),  # z_axis given while do_3D is False
    ((3, 100, 120), None, None, True, (100, 120, 3), True),  # do_3D without z_axis

    # 3D:
    ((5, 100, 120), None, None, True, (5, 100, 120, 3), True),  # grayscale volume without z_axis
    ((10, 100, 120, 3), -1, 0, False, (10, 100, 120, 3), True),  # 4D input while do_3D is False
]


@pytest.mark.parametrize(
    "input_shape, channel_axis, z_axis, do_3D, expected_shape, raises_error",
    _CONVERT_IMAGE_VALID_CASES + _CONVERT_IMAGE_INVALID_CASES,
)
def test_convert_image(input_shape, channel_axis, z_axis, do_3D, expected_shape, raises_error):
    """Check convert_image output shapes and ValueError on invalid arguments."""
    img = np.random.rand(*input_shape).astype(np.float32)
    axis_kwargs = dict(channel_axis=channel_axis, z_axis=z_axis, do_3D=do_3D)

    if not raises_error:
        converted_img = transforms.convert_image(img, **axis_kwargs)
        assert converted_img.shape == expected_shape, f"Expected shape {expected_shape}, but got {converted_img.shape}"
        return

    with pytest.raises(ValueError):
        transforms.convert_image(img, **axis_kwargs)

0 commit comments

Comments
 (0)