@@ -447,9 +447,12 @@ def update_axis(m_axis, to_squeeze, ndim):
447447 return m_axis
448448
449449
450- def convert_image_3d (x , channel_axis = None , z_axis = None ):
450+ def _convert_image_3d (x , channel_axis = None , z_axis = None ):
451451 """
452452 Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C).
453+
454+ Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis.
455+ Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified.
453456
454457 Args:
455458 x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D.
@@ -474,6 +477,12 @@ def convert_image_3d(x, channel_axis=None, z_axis=None):
474477 channels to ensure the output has exactly 3 channels.
475478 """
476479
480+ if x .ndim < 3 :
481+ raise ValueError (f"Input image must have at least 3 dimensions, input shape: { x .shape } , ndim={ x .ndim } " )
482+
483+ if z_axis is not None and z_axis < 0 :
484+ z_axis += x .ndim
485+
477486 # if image is ndim==3, assume it is greyscale 3D and use provided z_axis
478487 if x .ndim == 3 and z_axis is not None :
479488 # add in channel axis
@@ -484,7 +493,11 @@ def convert_image_3d(x, channel_axis=None, z_axis=None):
484493
485494
486495 if channel_axis is None or z_axis is None :
487- raise ValueError ("both channel_axis and z_axis must be specified when segmenting 3D images of ndim=4" )
496+ raise ValueError ("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters." )
497+ if channel_axis is not None and channel_axis < 0 :
498+ channel_axis += x .ndim
499+ if channel_axis is None or channel_axis >= x .ndim :
500+ raise IndexError (f"channel_axis { channel_axis } is out of bounds for input array with { x .ndim } dimensions" )
488501 assert x .ndim == 4 , f"input image must have ndim == 4, ndim={ x .ndim } "
489502
490503 x_dim_shapes = list (x .shape )
@@ -519,23 +532,26 @@ def convert_image_3d(x, channel_axis=None, z_axis=None):
519532 x = x [..., :x_chans_to_copy ]
520533 else :
521534 # less than 3 channels: pad up to
522- x_out = np .zeros ((num_z_layers , x_dim_shapes [0 ], x_dim_shapes [1 ], 3 ), dtype = x .dtype )
523- x_out [..., :x_chans_to_copy ] = x [...]
524- x = x_out
525- del x_out
535+ pad_width = [(0 , 0 ), (0 , 0 ), (0 , 0 ), (0 , 3 - x_chans_to_copy )]
536+ x = np .pad (x , pad_width , mode = 'constant' , constant_values = 0 )
526537
527538 return x
528539
529540
530541def convert_image (x , channel_axis = None , z_axis = None , do_3D = False ):
531- """Converts the image to have the z-axis first, channels last.
542+ """Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
543+ If more than 3 channels are provided, only the first 3 channels will be used.
544+
545+ Accepts:
546+ - 2D images with no channel dimension: `z_axis` and `channel_axis` must be `None`
547+ - 2D images with channel dimension: `channel_axis` will be guessed between first or last axis, can also specify `channel_axis`. `z_axis` must be `None`
548+ - 3D images with or without channels:
532549
533550 Args:
534551 x (numpy.ndarray or torch.Tensor): The input image.
535552 channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically.
536553 z_axis (int or None): The axis of the z-dimension in the input image. If None, the axis is determined automatically.
537554 do_3D (bool): Whether to process the image in 3D mode. Defaults to False.
538- nchan (int): The number of channels to keep if the input image has more than nchan channels.
539555
540556 Returns:
541557 numpy.ndarray: The converted image.
@@ -551,18 +567,23 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
551567 transforms_logger .warning ("torch array used as input, converting to numpy" )
552568 x = x .cpu ().numpy ()
553569
570+ # should be 2D
571+ if z_axis is not None and not do_3D :
572+ raise ValueError ("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3." )
573+
574+ # make sure that channel_axis and z_axis are specified if 3D
575+ if ndim == 4 and not do_3D :
576+ raise ValueError ("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4" )
577+
554578 # make sure that channel_axis and z_axis are specified if 3D
555579 if do_3D :
556- return convert_image_3d (x , channel_axis = channel_axis , z_axis = z_axis )
580+ return _convert_image_3d (x , channel_axis = channel_axis , z_axis = z_axis )
557581
558- if ndim == 4 :
559- raise ValueError ("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images." )
560-
561582 ######################## 2D reshaping ########################
562583 # if user specifies channel axis, return early
563584 if channel_axis is not None :
564585 if ndim == 2 :
565- raise ValueError ("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images." )
586+ raise ValueError ("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images of ndim=2 ." )
566587
567588 # Put channel axis last:
568589 # Find the indices of the dims that need to be put in dim 0 and 1
@@ -613,8 +634,9 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
613634 del x_out
614635 else :
615636 # something is wrong: yell
616- transforms_logger .critical (f"ERROR: Unexpected image shape: { str (x .shape )} " )
617- raise ValueError (f"ERROR: Unexpected image shape: { str (x .shape )} " )
637+ expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"
638+ transforms_logger .critical (f"ERROR: Unexpected image shape: { str (x .shape )} . Expected shapes: { expected_shapes } " )
639+ raise ValueError (f"ERROR: Unexpected image shape: { str (x .shape )} . Expected shapes: { expected_shapes } " )
618640
619641 return x
620642
0 commit comments