@@ -125,6 +125,7 @@ def get_prediction(
125125    channels_to_standardize : Optional [List [int ]] =  None ,
126126    mask : Optional [ArrayLike ] =  None ,
127127    prediction : Optional [ArrayLike ] =  None ,
128+     devices : Optional [List [str ]] =  None ,
128129) ->  ArrayLike :
129130    """Run prediction on a given volume. 
130131
@@ -143,6 +144,8 @@ def get_prediction(
143144            the foreground region of the mask. 
144145        prediction: An array like object for writing the prediction. 
145146            If not given, the prediction will be computed in memory. 
147+         devices: The devices for running prediction. If not given, will use the GPU 
148+             if available, otherwise the CPU. 
146149
147150    Returns: 
148151        The predicted volume. 
@@ -189,7 +192,7 @@ def get_prediction(
189192        # print(f"updated_tiling {updated_tiling}") 
190193        prediction  =  get_prediction_torch_em (
191194            input_volume , updated_tiling , model_path , model , verbose , with_channels ,
192-             mask = mask , prediction = prediction ,
195+             mask = mask , prediction = prediction ,  devices = devices , 
193196        )
194197
195198    return  prediction 
@@ -204,6 +207,7 @@ def get_prediction_torch_em(
204207    with_channels : bool  =  False ,
205208    mask : Optional [ArrayLike ] =  None ,
206209    prediction : Optional [ArrayLike ] =  None ,
210+     devices : Optional [List [str ]] =  None ,
207211) ->  np .ndarray :
208212    """Run prediction using torch-em on a given volume. 
209213
@@ -218,6 +222,8 @@ def get_prediction_torch_em(
218222            the foreground region of the mask. 
219223        prediction: An array like object for writing the prediction. 
220224            If not given, the prediction will be computed in memory. 
225+         devices: The devices for running prediction. If not given, will use the GPU 
226+             if available, otherwise the CPU. 
221227
222228    Returns: 
223229        The predicted volume. 
@@ -227,14 +233,15 @@ def get_prediction_torch_em(
227233    halo  =  [tiling ["halo" ]["z" ], tiling ["halo" ]["x" ], tiling ["halo" ]["y" ]]
228234
229235    t0  =  time .time ()
230-     device  =  "cuda"  if  torch .cuda .is_available () else  "cpu" 
236+     if  devices  is  None :
237+         devices  =  ["cuda"  if  torch .cuda .is_available () else  "cpu" ]
231238
232239    # Suppress warning when loading the model. 
233240    with  warnings .catch_warnings ():
234241        warnings .simplefilter ("ignore" )
235242        if  model  is  None :
236243            if  os .path .isdir (model_path ):  # Load the model from a torch_em checkpoint. 
237-                 model  =  torch_em .util .load_model (checkpoint = model_path , device = device )
244+                 model  =  torch_em .util .load_model (checkpoint = model_path , device = devices [ 0 ] )
238245            else :  # Load the model directly from a serialized pytorch model. 
239246                model  =  torch .load (model_path , weights_only = False )
240247
@@ -253,7 +260,7 @@ def get_prediction_torch_em(
253260
254261        preprocess  =  None  if  isinstance (input_volume , np .ndarray ) else  torch_em .transform .raw .standardize 
255262        prediction  =  predict_with_halo (
256-             input_volume , model , gpu_ids = [ device ] ,
263+             input_volume , model , gpu_ids = devices ,
257264            block_shape = block_shape , halo = halo ,
258265            preprocess = preprocess , with_channels = with_channels , mask = mask ,
259266            output = prediction ,
0 commit comments