@@ -221,6 +221,7 @@ def imread(filename):
221221def imread_2D (img_file ):
222222 """
223223 Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
224+ If the image has more than 3 channels, only the first 3 channels are kept.
224225
225226 Args:
226227 img_file (str): The path to the image file.
@@ -230,27 +231,30 @@ def imread_2D(img_file):
230231 """
231232 img = imread (img_file )
232233
233- # force XYC:
234+ if (img .ndim == 1 ) or (img .ndim == 4 ):
235+ raise ValueError ("img_file should have 2 or 3 dimensions, shape: %s" % img .shape )
236+
237+ # if image has no channel dimension, add one and return the image
238+ if img .ndim == 2 :
239+ img_out = np .zeros ((img .shape [0 ], img .shape [1 ], 3 ), dtype = img .dtype )
240+ img_out [:, :, 0 ] = img
241+ return img_out
242+
243+ # Otherwise, image will have a channel dimension, assume it's either first or last
244+ # force it to be last (XYC):
234245 if img .shape [0 ] < img .shape [- 1 ]:
235- # move channel to last dim:
236246 img = np .moveaxis (img , 0 , - 1 )
237247
238- nchan = img .shape [2 ]
248+ nchan = img .shape [- 1 ]
239249
240- if img .ndim == 3 :
241- if nchan == 3 :
242- # already has 3 channels
243- return img
244-
245- # ensure there are 3 channels
246- img_out = np .zeros ((img .shape [0 ], img .shape [1 ], 3 ), dtype = img .dtype )
247- copy_chan = min (3 , nchan )
248- img_out [:, :, :copy_chan ] = img [:, :, :copy_chan ]
249-
250- elif img .ndim == 2 :
251- # add a channel dimension
252- img_out = np .zeros ((img .shape [0 ], img .shape [1 ], 3 ), dtype = img .dtype )
253- img_out [:, :, 0 ] = img
250+ if nchan == 3 :
251+ # already has 3 channels
252+ return img
253+
254+ # ensure there are 3 channels
255+ img_out = np .zeros ((img .shape [0 ], img .shape [1 ], 3 ), dtype = img .dtype )
256+ copy_chan = min (3 , nchan )
257+ img_out [:, :, :copy_chan ] = img [:, :, :copy_chan ]
254258
255259 return img_out
256260
0 commit comments