@@ -125,6 +125,7 @@ def get_prediction(
125125 channels_to_standardize : Optional [List [int ]] = None ,
126126 mask : Optional [ArrayLike ] = None ,
127127 prediction : Optional [ArrayLike ] = None ,
128+ devices : Optional [List [str ]] = None ,
128129) -> ArrayLike :
129130 """Run prediction on a given volume.
130131
@@ -143,6 +144,8 @@ def get_prediction(
143144 the foreground region of the mask.
144145 prediction: An array like object for writing the prediction.
145146 If not given, the prediction will be computed in moemory.
147+ devices: The devices for running prediction. If not given will use the GPU
148+ if available, otherwise the CPU.
146149
147150 Returns:
148151 The predicted volume.
@@ -189,7 +192,7 @@ def get_prediction(
189192 # print(f"updated_tiling {updated_tiling}")
190193 prediction = get_prediction_torch_em (
191194 input_volume , updated_tiling , model_path , model , verbose , with_channels ,
192- mask = mask , prediction = prediction ,
195+ mask = mask , prediction = prediction , devices = devices ,
193196 )
194197
195198 return prediction
@@ -204,6 +207,7 @@ def get_prediction_torch_em(
204207 with_channels : bool = False ,
205208 mask : Optional [ArrayLike ] = None ,
206209 prediction : Optional [ArrayLike ] = None ,
210+ devices : Optional [List [str ]] = None ,
207211) -> np .ndarray :
208212 """Run prediction using torch-em on a given volume.
209213
@@ -218,6 +222,8 @@ def get_prediction_torch_em(
218222 the foreground region of the mask.
219223 prediction: An array like object for writing the prediction.
220224 If not given, the prediction will be computed in moemory.
225+ devices: The devices for running prediction. If not given will use the GPU
226+ if available, otherwise the CPU.
221227
222228 Returns:
223229 The predicted volume.
@@ -227,14 +233,15 @@ def get_prediction_torch_em(
227233 halo = [tiling ["halo" ]["z" ], tiling ["halo" ]["x" ], tiling ["halo" ]["y" ]]
228234
229235 t0 = time .time ()
230- device = "cuda" if torch .cuda .is_available () else "cpu"
236+ if devices is None :
237+ devices = ["cuda" if torch .cuda .is_available () else "cpu" ]
231238
232239 # Suppress warning when loading the model.
233240 with warnings .catch_warnings ():
234241 warnings .simplefilter ("ignore" )
235242 if model is None :
236243 if os .path .isdir (model_path ): # Load the model from a torch_em checkpoint.
237- model = torch_em .util .load_model (checkpoint = model_path , device = device )
244+ model = torch_em .util .load_model (checkpoint = model_path , device = devices [ 0 ] )
238245 else : # Load the model directly from a serialized pytorch model.
239246 model = torch .load (model_path , weights_only = False )
240247
@@ -253,7 +260,7 @@ def get_prediction_torch_em(
253260
254261 preprocess = None if isinstance (input_volume , np .ndarray ) else torch_em .transform .raw .standardize
255262 prediction = predict_with_halo (
256- input_volume , model , gpu_ids = [ device ] ,
263+ input_volume , model , gpu_ids = devices ,
257264 block_shape = block_shape , halo = halo ,
258265 preprocess = preprocess , with_channels = with_channels , mask = mask ,
259266 output = prediction ,
0 commit comments