Skip to content

Commit 1ca75d9

Browse files
authored
Merge pull request #1168 from MouseLand/mps_transformer_3d
Run transformer on MPS then do CPU post-processing
2 parents 785d339 + 79b5ca8 commit 1ca75d9

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

cellpose/models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,6 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
219219
if channels is not None:
220220
models_logger.warning("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used")
221221

222-
if self.device.type == 'mps' and do_3D:
223-
raise ValueError('MPS not working with 3D images. Disable GPU or use stitch threshold > 0')
224-
225222
if isinstance(x, list) or x.squeeze().ndim == 5:
226223
self.timing = []
227224
masks, styles, flows = [], [], []
@@ -421,6 +418,11 @@ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_thres
421418
min_size=15, max_size_fraction=0.4, niter=None,
422419
do_3D=False, stitch_threshold=0.0):
423420
""" compute masks from flows and cell probability """
421+
changed_device_from = None
422+
if self.device.type == "mps" and do_3D:
423+
models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
424+
self.device = torch.device("cpu")
425+
changed_device_from = "mps"
424426
Lz, Ly, Lx = shape[:3]
425427
tic = time.time()
426428
if do_3D:
@@ -470,4 +472,7 @@ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_thres
470472
if shape[0] > 1:
471473
models_logger.info("masks created in %2.2fs" % (flow_time))
472474

475+
if changed_device_from is not None:
476+
models_logger.info("switching back to device %s" % self.device)
477+
self.device = torch.device(changed_device_from)
473478
return masks

0 commit comments

Comments
 (0)