Skip to content

Commit eb2d3ed

Browse files
Support setting multiple devices in segmentation
1 parent 9c252ed commit eb2d3ed

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

synapse_net/inference/scalable_segmentation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def scalable_segmentation(
7979
prediction: Optional[ArrayLike] = None,
8080
verbose: bool = True,
8181
mask: Optional[ArrayLike] = None,
82+
devices: Optional[List[str]] = None,
8283
) -> None:
8384
"""Run segmentation based on a prediction with foreground and boundary channel.
8485
@@ -100,6 +101,8 @@ def scalable_segmentation(
100101
If given, this can be a numpy array, a zarr array, or similar.
101102
If not given will be stored in a temporary n5 array.
102103
verbose: Whether to print timing information.
104+
devices: The devices for running prediction. If not given will use the GPU
105+
if available, otherwise the CPU.
103106
"""
104107
if mask is not None:
105108
raise NotImplementedError
@@ -133,5 +136,5 @@ def scalable_segmentation(
133136
seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)
134137

135138
# Run prediction and segmentation.
136-
get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
139+
get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
137140
_run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)

synapse_net/inference/util.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def get_prediction(
125125
channels_to_standardize: Optional[List[int]] = None,
126126
mask: Optional[ArrayLike] = None,
127127
prediction: Optional[ArrayLike] = None,
128+
devices: Optional[List[str]] = None,
128129
) -> ArrayLike:
129130
"""Run prediction on a given volume.
130131
@@ -143,6 +144,8 @@ def get_prediction(
143144
the foreground region of the mask.
144145
prediction: An array like object for writing the prediction.
145146
If not given, the prediction will be computed in memory.
147+
devices: The devices for running prediction. If not given will use the GPU
148+
if available, otherwise the CPU.
146149
147150
Returns:
148151
The predicted volume.
@@ -189,7 +192,7 @@ def get_prediction(
189192
# print(f"updated_tiling {updated_tiling}")
190193
prediction = get_prediction_torch_em(
191194
input_volume, updated_tiling, model_path, model, verbose, with_channels,
192-
mask=mask, prediction=prediction,
195+
mask=mask, prediction=prediction, devices=devices,
193196
)
194197

195198
return prediction
@@ -204,6 +207,7 @@ def get_prediction_torch_em(
204207
with_channels: bool = False,
205208
mask: Optional[ArrayLike] = None,
206209
prediction: Optional[ArrayLike] = None,
210+
devices: Optional[List[str]] = None,
207211
) -> np.ndarray:
208212
"""Run prediction using torch-em on a given volume.
209213
@@ -218,6 +222,8 @@ def get_prediction_torch_em(
218222
the foreground region of the mask.
219223
prediction: An array like object for writing the prediction.
220224
If not given, the prediction will be computed in memory.
225+
devices: The devices for running prediction. If not given will use the GPU
226+
if available, otherwise the CPU.
221227
222228
Returns:
223229
The predicted volume.
@@ -227,14 +233,15 @@ def get_prediction_torch_em(
227233
halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]
228234

229235
t0 = time.time()
230-
device = "cuda" if torch.cuda.is_available() else "cpu"
236+
if devices is None:
237+
devices = ["cuda" if torch.cuda.is_available() else "cpu"]
231238

232239
# Suppress warning when loading the model.
233240
with warnings.catch_warnings():
234241
warnings.simplefilter("ignore")
235242
if model is None:
236243
if os.path.isdir(model_path): # Load the model from a torch_em checkpoint.
237-
model = torch_em.util.load_model(checkpoint=model_path, device=device)
244+
model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
238245
else: # Load the model directly from a serialized pytorch model.
239246
model = torch.load(model_path, weights_only=False)
240247

@@ -253,7 +260,7 @@ def get_prediction_torch_em(
253260

254261
preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
255262
prediction = predict_with_halo(
256-
input_volume, model, gpu_ids=[device],
263+
input_volume, model, gpu_ids=devices,
257264
block_shape=block_shape, halo=halo,
258265
preprocess=preprocess, with_channels=with_channels, mask=mask,
259266
output=prediction,

0 commit comments

Comments
 (0)