Skip to content

Commit fffa9f4

Browse files
Merge pull request #138 from MannLabs/dev_mp
Fix Out of Bounds issue upon multiple multiprocessing worker runs
2 parents 622a45f + 43f4a98 commit fffa9f4

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/scportrait/pipeline/segmentation/segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,6 @@ def _resolve_sharding(self, sharding_plan):
742742
local_hf = h5py.File(local_output, "r")
743743
local_hdf_labels = local_hf.get(self.DEFAULT_MASK_NAME)[:]
744744

745-
print(type(local_hdf_labels))
746745
shifted_map, edge_labels = shift_labels(
747746
local_hdf_labels,
748747
class_id_shift,
@@ -902,8 +901,9 @@ def _resolve_sharding(self, sharding_plan):
902901
if not self.deep_debug:
903902
self._cleanup_shards(sharding_plan)
904903

905-
def _initializer_function(self, gpu_id_list):
904+
def _initializer_function(self, gpu_id_list, n_processes):
906905
current_process().gpu_id_list = gpu_id_list
906+
current_process().n_processes = n_processes
907907

908908
def _perform_segmentation(self, shard_list):
909909
# get GPU status
@@ -921,7 +921,7 @@ def _perform_segmentation(self, shard_list):
921921
with mp.get_context(self.context).Pool(
922922
processes=self.n_processes,
923923
initializer=self._initializer_function,
924-
initargs=[self.gpu_id_list],
924+
initargs=[self.gpu_id_list, self.n_processes],
925925
) as pool:
926926
list(
927927
tqdm(

src/scportrait/pipeline/segmentation/workflows.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from skimage.filters import median
1616
from skimage.morphology import binary_erosion, dilation, disk, erosion
1717
from skimage.segmentation import watershed
18+
import _multiprocessing
1819

1920
from scportrait.pipeline._utils.segmentation import (
2021
contact_filter,
@@ -1353,6 +1354,9 @@ def _check_gpu_status(self):
13531354
gpu_id_list = current.gpu_id_list
13541355
cpu_id = int(cpu_name[cpu_name.find("-") + 1 :]) - 1
13551356

1357+
if cpu_id >= len(gpu_id_list):
1358+
cpu_id = cpu_id%current.n_processes
1359+
13561360
# track gpu_id and update GPU status
13571361
self.gpu_id = gpu_id_list[cpu_id]
13581362
self.status = "multi_GPU"

0 commit comments

Comments
 (0)