Skip to content

Commit 17fb25f

Browse files
authored
Merge pull request #1374 from MouseLand/batch_img_refactor
Batch img refactor
2 parents 9604421 + f08576d commit 17fb25f

File tree

3 files changed

+58
-116
lines changed

3 files changed

+58
-116
lines changed

cellpose/models.py

Lines changed: 25 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -326,16 +326,37 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
326326
torch.cuda.empty_cache()
327327
gc.collect()
328328

329-
if resample:
329+
if resample:
330+
# upsample flows before computing them:
331+
# dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
332+
# cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
333+
334+
# resize XY then YZ and then put channels first
335+
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False)
336+
dP = transforms.resize_image(dP.transpose(1, 0, 2, 3), Lx=Lx_0, Ly=Lz_0, no_channels=False)
337+
dP = dP.transpose(3, 1, 0, 2)
338+
339+
# resize cellprob:
340+
cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
341+
cellprob = transforms.resize_image(cellprob.transpose(1, 0, 2), Lx=Lx_0, Ly=Lz_0, no_channels=True)
342+
cellprob = cellprob.transpose(1, 0, 2)
343+
344+
345+
# 2d case:
346+
if resample and not do_3D:
330347
# upsample flows before computing them:
331-
dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
332-
cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
348+
# dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
349+
# cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
350+
351+
# 2D images have N = 1 in batch dimension:
352+
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False).transpose(3, 0, 1, 2)
353+
cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
333354

334355
if compute_masks:
335356
# use user niter if specified, otherwise scale niter (200) with diameter
336357
niter_scale = 1 if image_scaling is None else image_scaling
337358
niter = int(200/niter_scale) if niter is None or niter == 0 else niter
338-
masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
359+
masks = self._compute_masks((Lz_0 or nimg, Ly_0, Lx_0), dP, cellprob, flow_threshold=flow_threshold,
339360
cellprob_threshold=cellprob_threshold, min_size=min_size,
340361
max_size_fraction=max_size_fraction, niter=niter,
341362
stitch_threshold=stitch_threshold, do_3D=do_3D)
@@ -344,112 +365,9 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
344365

345366
masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()
346367

347-
# undo resizing:
348-
if image_scaling is not None or anisotropy is not None:
349-
350-
dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0) # works for 2 or 3D:
351-
cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
352-
353-
if do_3D:
354-
if compute_masks:
355-
# Rescale xy then xz:
356-
masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
357-
masks = masks.transpose(1, 0, 2)
358-
masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
359-
masks = masks.transpose(1, 0, 2)
360-
361-
else:
362-
# 2D or 3D stitching case:
363-
if compute_masks:
364-
masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
365-
366368
return masks, [plot.dx_to_circ(dP), dP, cellprob], styles
367369

368370

369-
def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
370-
"""
371-
Resize cellprob array to specified dimensions for either 2D or 3D.
372-
373-
Parameters:
374-
prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
375-
to_y_size (int): The target size along the Y-axis.
376-
to_x_size (int): The target size along the X-axis.
377-
to_z_size (int, optional): The target size along the Z-axis. Required
378-
for 3D cellprobs.
379-
380-
Returns:
381-
numpy.ndarray: The resized cellprobs array with the same number of dimensions
382-
as the input.
383-
384-
Raises:
385-
ValueError: If the input cellprobs array does not have 3 or 4 dimensions.
386-
"""
387-
prob_shape = prob.shape
388-
prob = prob.squeeze()
389-
squeeze_happened = prob.shape != prob_shape
390-
prob_shape = np.array(prob_shape)
391-
392-
if prob.ndim == 2:
393-
# 2D case:
394-
prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
395-
if squeeze_happened:
396-
prob = np.expand_dims(prob, int(np.argwhere(prob_shape == 1))) # add back empty axis for compatibility
397-
elif prob.ndim == 3:
398-
# 3D case:
399-
prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
400-
prob = prob.transpose(1, 0, 2)
401-
prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
402-
prob = prob.transpose(1, 0, 2)
403-
else:
404-
raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')
405-
406-
return prob
407-
408-
409-
def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
410-
"""
411-
Resize gradient arrays to specified dimensions for either 2D or 3D gradients.
412-
413-
Parameters:
414-
grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
415-
to_y_size (int): The target size along the Y-axis.
416-
to_x_size (int): The target size along the X-axis.
417-
to_z_size (int, optional): The target size along the Z-axis. Required
418-
for 3D gradients.
419-
420-
Returns:
421-
numpy.ndarray: The resized gradient array with the same number of dimensions
422-
as the input.
423-
424-
Raises:
425-
ValueError: If the input gradient array does not have 3 or 4 dimensions.
426-
"""
427-
grads_shape = grads.shape
428-
grads = grads.squeeze()
429-
squeeze_happened = grads.shape != grads_shape
430-
grads_shape = np.array(grads_shape)
431-
432-
if grads.ndim == 3:
433-
# 2D case, with XY flows in 2 channels:
434-
grads = np.moveaxis(grads, 0, -1) # Put gradients last
435-
grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
436-
grads = np.moveaxis(grads, -1, 0) # Put gradients first
437-
438-
if squeeze_happened:
439-
grads = np.expand_dims(grads, int(np.argwhere(grads_shape == 1))) # add back empty axis for compatibility
440-
elif grads.ndim == 4:
441-
# dP has gradients that can be treated as channels:
442-
grads = grads.transpose(1, 2, 3, 0) # move gradients last:
443-
grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
444-
grads = grads.transpose(1, 0, 2, 3) # switch axes to resize again
445-
grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
446-
grads = grads.transpose(3, 1, 0, 2) # undo transposition
447-
else:
448-
raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')
449-
450-
return grads
451-
452-
453371
def _run_net(self, x,
454372
augment=False,
455373
batch_size=8, tile_overlap=0.1,

cellpose/transforms.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
545545
Accepts:
546546
- 2D images with no channel dimension: `z_axis` and `channel_axis` must be `None`
547547
- 2D images with channel dimension: `channel_axis` will be guessed between first or last axis, can also specify `channel_axis`. `z_axis` must be `None`
548+
- Batch of 2D images having shape: [N, H, W, C] with N images in the batch
548549
- 3D images with or without channels:
549550
550551
Args:
@@ -554,11 +555,10 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
554555
do_3D (bool): Whether to process the image in 3D mode. Defaults to False.
555556
556557
Returns:
557-
numpy.ndarray: The converted image.
558+
numpy.ndarray: The converted image with channels last.
558559
559560
Raises:
560561
ValueError: If the input image is 2D and do_3D is True.
561-
ValueError: If the input image is 4D and do_3D is False.
562562
"""
563563

564564
# check if image is a torch array instead of numpy array, convert to numpy
@@ -571,10 +571,6 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
571571
if z_axis is not None and not do_3D:
572572
raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")
573573

574-
# make sure that channel_axis and z_axis are specified if 3D
575-
if ndim == 4 and not do_3D:
576-
raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
577-
578574
# make sure that channel_axis and z_axis are specified if 3D
579575
if do_3D:
580576
return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)
@@ -616,6 +612,7 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
616612
x_out[..., 0] = x
617613
x = x_out
618614
del x_out
615+
transforms_logger.info(f'processing grayscale image with {x.shape[0], x.shape[1]} HW')
619616
elif ndim == 3:
620617
# assume 2d with channels
621618
# find dim with smaller size between first and last dims
@@ -632,6 +629,20 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
632629
x_out[..., :num_channels] = x[..., :num_channels]
633630
x = x_out
634631
del x_out
632+
transforms_logger.info(f'processing image with {x.shape[0], x.shape[1]} HW, and {x.shape[2]} channels')
633+
elif ndim == 4:
634+
# assume batch of 2d with channels
635+
636+
# zero padding up to 3 channels:
637+
num_channels = x.shape[-1]
638+
if num_channels > 3:
639+
transforms_logger.warning("Found more than 3 channels, only using first 3")
640+
num_channels = 3
641+
x_out = np.zeros((x.shape[0], x.shape[1], x.shape[2], 3), dtype=x.dtype)
642+
x_out[..., :num_channels] = x[..., :num_channels]
643+
x = x_out
644+
del x_out
645+
transforms_logger.info(f'processing image batch with {x.shape[0]} images, {x.shape[1], x.shape[2]} HW, and {x.shape[3]} channels')
635646
else:
636647
# something is wrong: yell
637648
expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"

tests/test_output.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from cellpose import io, metrics, utils, models
22
import pytest
33
from subprocess import check_output, STDOUT
4+
from pathlib import Path
45
import os
56
import numpy as np
67

@@ -43,7 +44,9 @@ def clear_output(data_dir, image_names):
4344
(True, True, 40),
4445
(True, True, None),
4546
(False, True, None),
46-
(False, False, None)
47+
(False, False, None),
48+
(True, False, None),
49+
(True, False, 40)
4750
]
4851
)
4952
def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample, diameter):
@@ -56,11 +59,21 @@ def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer,
5659

5760
masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True, compute_masks=compute_masks, resample=resample, diameter=diameter)
5861

59-
if not compute_masks or diameter:
62+
if not compute_masks:
6063
# not compute_masks won't return masks so can't check
61-
# different diameter will give different masks, so can't check
6264
return
6365

66+
if diameter and compute_masks:
67+
# size of masks will be different, so need to adjust calculation
68+
masks_gt_file = Path(str(img_file).replace('_tif.tif', '_tif_cp4_gt_masks.png'))
69+
masks_gt = io.imread_2D(masks_gt_file)
70+
71+
masks_pred_shape = [int(s * diameter/30) for s in masks_pred.shape]
72+
assert [a == b for a, b in zip(masks_gt.shape[:2], masks_pred_shape[:2])]
73+
74+
# don't compare the images, because they are different sizes and won't match
75+
return
76+
6477
io.imsave(data_dir / '2D' / (img_file.stem + "_cp_masks.png"), masks_pred)
6578
# flowsp_pred = np.concatenate([flows_pred[1], flows_pred[2][None, ...]], axis=0)
6679
# mse = np.sqrt((flowsp_pred - flowps) ** 2).sum()

0 commit comments

Comments
 (0)