@@ -219,9 +219,6 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
219219 if channels is not None :
220220 models_logger .warning ("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used" )
221221
222- if self .device .type == 'mps' and do_3D :
223- raise ValueError ('MPS not working with 3D images. Disable GPU or use stitch threshold > 0' )
224-
225222 if isinstance (x , list ) or x .squeeze ().ndim == 5 :
226223 self .timing = []
227224 masks , styles , flows = [], [], []
@@ -421,6 +418,11 @@ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_thres
421418 min_size = 15 , max_size_fraction = 0.4 , niter = None ,
422419 do_3D = False , stitch_threshold = 0.0 ):
423420 """ compute masks from flows and cell probability """
421+ changed_device_from = None
422+ if self .device .type == "mps" and do_3D :
423+ models_logger .warning ("MPS does not support 3D post-processing, switching to CPU" )
424+ self .device = torch .device ("cpu" )
425+ changed_device_from = "mps"
424426 Lz , Ly , Lx = shape [:3 ]
425427 tic = time .time ()
426428 if do_3D :
@@ -470,4 +472,7 @@ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_thres
470472 if shape [0 ] > 1 :
471473 models_logger .info ("masks created in %2.2fs" % (flow_time ))
472474
475+ if changed_device_from is not None :
476+ models_logger .info ("switching back to device %s" % self .device )
477+ self .device = torch .device (changed_device_from )
473478 return masks
0 commit comments