Skip to content

Commit 1575f33

Browse files
committed
fix 2d channel img reading logic
1 parent 7e77525 commit 1575f33

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

cellpose/io.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def imread(filename):
221221
def imread_2D(img_file):
222222
"""
223223
Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
224+
If the image has more than 3 channels, only the first 3 channels are kept.
224225
225226
Args:
226227
img_file (str): The path to the image file.
@@ -230,27 +231,30 @@ def imread_2D(img_file):
230231
"""
231232
img = imread(img_file)
232233

233-
# force XYC:
234+
if (img.ndim == 1) or (img.ndim == 4):
235+
raise ValueError("img_file should have 2 or 3 dimensions, shape: %s" % img.shape)
236+
237+
# if image has no channel dimension, add one and return the image
238+
if img.ndim == 2:
239+
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
240+
img_out[:, :, 0] = img
241+
return img_out
242+
243+
# Otherwise, image will have a channel dimension, assume it's either first or last
244+
# force it to be last (XYC):
234245
if img.shape[0] < img.shape[-1]:
235-
# move channel to last dim:
236246
img = np.moveaxis(img, 0, -1)
237247

238-
nchan = img.shape[2]
248+
nchan = img.shape[-1]
239249

240-
if img.ndim == 3:
241-
if nchan == 3:
242-
# already has 3 channels
243-
return img
244-
245-
# ensure there are 3 channels
246-
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
247-
copy_chan = min(3, nchan)
248-
img_out[:, :, :copy_chan] = img[:, :, :copy_chan]
249-
250-
elif img.ndim == 2:
251-
# add a channel dimension
252-
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
253-
img_out[:, :, 0] = img
250+
if nchan == 3:
251+
# already has 3 channels
252+
return img
253+
254+
# ensure there are 3 channels
255+
img_out = np.zeros((img.shape[0], img.shape[1], 3), dtype=img.dtype)
256+
copy_chan = min(3, nchan)
257+
img_out[:, :, :copy_chan] = img[:, :, :copy_chan]
254258

255259
return img_out
256260

0 commit comments

Comments
 (0)