Skip to content

Commit 79b5ca8

Browse files
committed
revert back to MPS after mask creation
1 parent ce50456 commit 79b5ca8

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

cellpose/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,11 @@ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_thres
418418
min_size=15, max_size_fraction=0.4, niter=None,
419419
do_3D=False, stitch_threshold=0.0):
420420
""" compute masks from flows and cell probability """
421-
421+
changed_device_from = None
422422
if self.device.type == "mps" and do_3D:
423423
models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
424424
self.device = torch.device("cpu")
425+
changed_device_from = "mps"
425426
Lz, Ly, Lx = shape[:3]
426427
tic = time.time()
427428
if do_3D:
@@ -471,4 +472,7 @@ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_thres
471472
if shape[0] > 1:
472473
models_logger.info("masks created in %2.2fs" % (flow_time))
473474

475+
if changed_device_from is not None:
476+
models_logger.info("switching back to device %s" % self.device)
477+
self.device = torch.device(changed_device_from)
474478
return masks

0 commit comments

Comments
 (0)