@@ -559,6 +559,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
559559 styles (list, np.ndarray): style vector summarizing each image of size 256.
560560 imgs (list of 2D/3D arrays): Restored images
561561 """
562+
562563 if isinstance (normalize , dict ):
563564 normalize_params = {** normalize_default , ** normalize }
564565 elif not isinstance (normalize , bool ):
@@ -578,8 +579,11 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
578579 # turn off special normalization for segmentation
579580 normalize_params = normalize_default
580581
581- # change channels for segmentation (denoise model outputs up to 2 channels)
582- channels_new = [0 , 0 ] if channels [0 ] == 0 else [1 , 2 ]
582+ # change channels for segmentation
583+ if channels is not None :
584+ channels_new = [0 , 0 ] if channels [0 ] == 0 else [1 , 2 ]
585+ else :
586+ channels_new = None
583587 # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
584588 diameter = self .dn .diam_mean if self .dn .ratio > 1 else diameter
585589 masks , flows , styles = self .cp .eval (
@@ -759,7 +763,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
759763 else :
760764 # reshape image
761765 x = transforms .convert_image (x , channels , channel_axis = channel_axis ,
762- z_axis = z_axis , do_3D = do_3D )
766+ z_axis = z_axis , do_3D = do_3D , nchan = None )
763767 if x .ndim < 4 :
764768 squeeze = True
765769 x = x [np .newaxis , ...]
@@ -790,7 +794,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
790794 elif rescale is None :
791795 rescale = 1.0
792796
793- if np .ptp (x [..., - 1 ]) < 1e-3 or channels [- 1 ] == 0 :
797+ if np .ptp (x [..., - 1 ]) < 1e-3 or ( channels is not None and channels [- 1 ] == 0 ) :
794798 x = x [..., :1 ]
795799
796800 for c in range (x .shape [- 1 ]):
0 commit comments