@@ -326,16 +326,37 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
326326 torch .cuda .empty_cache ()
327327 gc .collect ()
328328
329- if resample :
329+ if resample :
330+ # upsample flows flows before computing them:
331+ # dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
332+ # cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
333+
334+ # resize XY then YZ and then put channels first
335+ dP = transforms .resize_image (dP .transpose (1 , 2 , 3 , 0 ), Ly = Ly_0 , Lx = Lx_0 , no_channels = False )
336+ dP = transforms .resize_image (dP .transpose (1 , 0 , 2 , 3 ), Lx = Lx_0 , Ly = Lz_0 , no_channels = False )
337+ dP = dP .transpose (3 , 1 , 0 , 2 )
338+
339+ # resize cellprob:
340+ cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
341+ cellprob = transforms .resize_image (cellprob .transpose (1 , 0 , 2 ), Lx = Lx_0 , Ly = Lz_0 , no_channels = True )
342+ cellprob = cellprob .transpose (1 , 0 , 2 )
343+
344+
345+ # 2d case:
346+ if resample and not do_3D :
330347 # upsample flows before computing them:
331- dP = self ._resize_gradients (dP , to_y_size = Ly_0 , to_x_size = Lx_0 , to_z_size = Lz_0 )
332- cellprob = self ._resize_cellprob (cellprob , to_x_size = Lx_0 , to_y_size = Ly_0 , to_z_size = Lz_0 )
348+ # dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
349+ # cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
350+
351+ # 2D images have N = 1 in batch dimension:
352+ dP = transforms .resize_image (dP .transpose (1 , 2 , 3 , 0 ), Ly = Ly_0 , Lx = Lx_0 , no_channels = False ).transpose (3 , 0 , 1 , 2 )
353+ cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
333354
334355 if compute_masks :
335356 # use user niter if specified, otherwise scale niter (200) with diameter
336357 niter_scale = 1 if image_scaling is None else image_scaling
337358 niter = int (200 / niter_scale ) if niter is None or niter == 0 else niter
338- masks = self ._compute_masks (x . shape , dP , cellprob , flow_threshold = flow_threshold ,
359+ masks = self ._compute_masks (( Lz_0 or nimg , Ly_0 , Lx_0 ) , dP , cellprob , flow_threshold = flow_threshold ,
339360 cellprob_threshold = cellprob_threshold , min_size = min_size ,
340361 max_size_fraction = max_size_fraction , niter = niter ,
341362 stitch_threshold = stitch_threshold , do_3D = do_3D )
@@ -344,112 +365,9 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
344365
345366 masks , dP , cellprob = masks .squeeze (), dP .squeeze (), cellprob .squeeze ()
346367
347- # undo resizing:
348- if image_scaling is not None or anisotropy is not None :
349-
350- dP = self ._resize_gradients (dP , to_y_size = Ly_0 , to_x_size = Lx_0 , to_z_size = Lz_0 ) # works for 2 or 3D:
351- cellprob = self ._resize_cellprob (cellprob , to_x_size = Lx_0 , to_y_size = Ly_0 , to_z_size = Lz_0 )
352-
353- if do_3D :
354- if compute_masks :
355- # Rescale xy then xz:
356- masks = transforms .resize_image (masks , Ly = Ly_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
357- masks = masks .transpose (1 , 0 , 2 )
358- masks = transforms .resize_image (masks , Ly = Lz_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
359- masks = masks .transpose (1 , 0 , 2 )
360-
361- else :
362- # 2D or 3D stitching case:
363- if compute_masks :
364- masks = transforms .resize_image (masks , Ly = Ly_0 , Lx = Lx_0 , no_channels = True , interpolation = cv2 .INTER_NEAREST )
365-
366368 return masks , [plot .dx_to_circ (dP ), dP , cellprob ], styles
367369
368370
369- def _resize_cellprob (self , prob : np .ndarray , to_y_size : int , to_x_size : int , to_z_size : int = None ) -> np .ndarray :
370- """
371- Resize cellprob array to specified dimensions for either 2D or 3D.
372-
373- Parameters:
374- prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
375- to_y_size (int): The target size along the Y-axis.
376- to_x_size (int): The target size along the X-axis.
377- to_z_size (int, optional): The target size along the Z-axis. Required
378- for 3D cellprobs.
379-
380- Returns:
381- numpy.ndarray: The resized cellprobs array with the same number of dimensions
382- as the input.
383-
384- Raises:
385- ValueError: If the input cellprobs array does not have 3 or 4 dimensions.
386- """
387- prob_shape = prob .shape
388- prob = prob .squeeze ()
389- squeeze_happened = prob .shape != prob_shape
390- prob_shape = np .array (prob_shape )
391-
392- if prob .ndim == 2 :
393- # 2D case:
394- prob = transforms .resize_image (prob , Ly = to_y_size , Lx = to_x_size , no_channels = True )
395- if squeeze_happened :
396- prob = np .expand_dims (prob , int (np .argwhere (prob_shape == 1 ))) # add back empty axis for compatibility
397- elif prob .ndim == 3 :
398- # 3D case:
399- prob = transforms .resize_image (prob , Ly = to_y_size , Lx = to_x_size , no_channels = True )
400- prob = prob .transpose (1 , 0 , 2 )
401- prob = transforms .resize_image (prob , Ly = to_z_size , Lx = to_x_size , no_channels = True )
402- prob = prob .transpose (1 , 0 , 2 )
403- else :
404- raise ValueError (f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: { prob .shape } ' )
405-
406- return prob
407-
408-
409- def _resize_gradients (self , grads : np .ndarray , to_y_size : int , to_x_size : int , to_z_size : int = None ) -> np .ndarray :
410- """
411- Resize gradient arrays to specified dimensions for either 2D or 3D gradients.
412-
413- Parameters:
414- grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
415- to_y_size (int): The target size along the Y-axis.
416- to_x_size (int): The target size along the X-axis.
417- to_z_size (int, optional): The target size along the Z-axis. Required
418- for 3D gradients.
419-
420- Returns:
421- numpy.ndarray: The resized gradient array with the same number of dimensions
422- as the input.
423-
424- Raises:
425- ValueError: If the input gradient array does not have 3 or 4 dimensions.
426- """
427- grads_shape = grads .shape
428- grads = grads .squeeze ()
429- squeeze_happened = grads .shape != grads_shape
430- grads_shape = np .array (grads_shape )
431-
432- if grads .ndim == 3 :
433- # 2D case, with XY flows in 2 channels:
434- grads = np .moveaxis (grads , 0 , - 1 ) # Put gradients last
435- grads = transforms .resize_image (grads , Ly = to_y_size , Lx = to_x_size , no_channels = False )
436- grads = np .moveaxis (grads , - 1 , 0 ) # Put gradients first
437-
438- if squeeze_happened :
439- grads = np .expand_dims (grads , int (np .argwhere (grads_shape == 1 ))) # add back empty axis for compatibility
440- elif grads .ndim == 4 :
441- # dP has gradients that can be treated as channels:
442- grads = grads .transpose (1 , 2 , 3 , 0 ) # move gradients last:
443- grads = transforms .resize_image (grads , Ly = to_y_size , Lx = to_x_size , no_channels = False )
444- grads = grads .transpose (1 , 0 , 2 , 3 ) # switch axes to resize again
445- grads = transforms .resize_image (grads , Ly = to_z_size , Lx = to_x_size , no_channels = False )
446- grads = grads .transpose (3 , 1 , 0 , 2 ) # undo transposition
447- else :
448- raise ValueError (f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: { grads .shape } ' )
449-
450- return grads
451-
452-
453371 def _run_net (self , x ,
454372 augment = False ,
455373 batch_size = 8 , tile_overlap = 0.1 ,
0 commit comments