Skip to content

Commit 793e968

Browse files
authored
Merge pull request #202 from TGSAI/enh/ingestion_exception_handling
Ingestion: Replace `multiprocessing.Pool` with `concurrent.futures.ProcessPoolExecutor`
2 parents 79e7008 + 75bfa50 commit 793e968

File tree

3 files changed

+27
-54
lines changed

3 files changed

+27
-54
lines changed

src/mdio/segy/_workers.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -208,17 +208,3 @@ def trace_worker(
208208
max_val = tmp_data.max()
209209

210210
return count, chunk_sum, chunk_sum_squares, min_val, max_val
211-
212-
213-
# tqdm only works properly with pool.map
214-
# However, we need pool.starmap because we have more than one
215-
# argument to make pool.map work with multiple arguments, we
216-
# wrap the function and consolidate arguments to one
217-
def trace_worker_map(args):
218-
"""Wrapper for trace worker to use with tqdm."""
219-
return trace_worker(*args)
220-
221-
222-
def header_scan_worker_map(args):
223-
"""Wrapper for header scan worker to use with tqdm."""
224-
return header_scan_worker(*args)

src/mdio/segy/blocked_io.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import multiprocessing as mp
7+
from concurrent.futures import ProcessPoolExecutor
78
from itertools import repeat
89

910
import numpy as np
@@ -19,7 +20,7 @@
1920

2021
from mdio.core import Grid
2122
from mdio.core.indexing import ChunkIterator
22-
from mdio.segy._workers import trace_worker_map
23+
from mdio.segy._workers import trace_worker
2324
from mdio.segy.byte_utils import ByteOrder
2425
from mdio.segy.byte_utils import Dtype
2526
from mdio.segy.creation import concat_files
@@ -132,35 +133,27 @@ def to_zarr(
132133
chunker = ChunkIterator(trace_array, chunk_samples=False)
133134
num_chunks = len(chunker)
134135

135-
# Setting all multiprocessing parameters.
136-
parallel_inputs = zip( # noqa: B905
137-
repeat(segy_path),
138-
repeat(trace_array),
139-
repeat(header_array),
140-
repeat(grid),
141-
chunker,
142-
repeat(segy_endian),
143-
)
144-
145-
# This is for Unix async writes to s3fs/fsspec, when using
146-
# multiprocessing. By default, Linux uses the 'fork' method.
147-
# 'spawn' is a little slower to spool up processes, but 'fork'
148-
# doesn't work. If you don't use this, processes get deadlocked
149-
# on cloud stores. 'spawn' is default in Windows.
136+
# For Unix async writes with s3fs/fsspec & multiprocessing,
137+
# use 'spawn' instead of default 'fork' to avoid deadlocks
138+
# on cloud stores. Slower but necessary. Default on Windows.
139+
num_workers = min(num_chunks, NUM_CORES)
150140
context = mp.get_context("spawn")
141+
executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
151142

152-
# This is the chunksize for multiprocessing. Not to be confused
153-
# with Zarr chunksize.
154-
num_workers = min(num_chunks, NUM_CORES)
143+
# Chunksize here is for multiprocessing, not Zarr chunksize.
155144
pool_chunksize, extra = divmod(num_chunks, num_workers * 4)
156145
pool_chunksize += 1 if extra else pool_chunksize
157146

158147
tqdm_kw = dict(unit="block", dynamic_ncols=True)
159-
with context.Pool(num_workers) as pool:
160-
# pool.imap is lazy
161-
lazy_work = pool.imap(
162-
func=trace_worker_map,
163-
iterable=parallel_inputs,
148+
with executor:
149+
lazy_work = executor.map(
150+
trace_worker, # fn
151+
repeat(segy_path),
152+
repeat(trace_array),
153+
repeat(header_array),
154+
repeat(grid),
155+
chunker,
156+
repeat(segy_endian),
164157
chunksize=pool_chunksize,
165158
)
166159

src/mdio/segy/parsers.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from __future__ import annotations
55

6+
from concurrent.futures import ProcessPoolExecutor
67
from itertools import repeat
78
from math import ceil
8-
from multiprocessing import Pool
99
from typing import Any
1010
from typing import Sequence
1111

@@ -15,7 +15,7 @@
1515
from tqdm.auto import tqdm
1616

1717
from mdio.core import Dimension
18-
from mdio.segy._workers import header_scan_worker_map
18+
from mdio.segy._workers import header_scan_worker
1919

2020

2121
NUM_CORES = cpu_count(logical=False)
@@ -104,24 +104,18 @@ def parse_trace_headers(
104104

105105
trace_ranges.append((start, stop))
106106

107-
# Note: Make sure the order of this is exactly
108-
# the same as the function call.
109-
parallel_inputs = zip( # noqa: B905 or strict=False >= py3.10
110-
repeat(segy_path),
111-
trace_ranges,
112-
repeat(byte_locs),
113-
repeat(byte_lengths),
114-
repeat(segy_endian),
115-
)
116-
117107
num_workers = min(n_blocks, NUM_CORES)
118108

119109
tqdm_kw = dict(unit="block", dynamic_ncols=True)
120-
with Pool(num_workers) as pool:
110+
with ProcessPoolExecutor(num_workers) as executor:
121111
# pool.imap is lazy
122-
lazy_work = pool.imap(
123-
func=header_scan_worker_map,
124-
iterable=parallel_inputs,
112+
lazy_work = executor.map(
113+
header_scan_worker, # fn
114+
repeat(segy_path),
115+
trace_ranges,
116+
repeat(byte_locs),
117+
repeat(byte_lengths),
118+
repeat(segy_endian),
125119
chunksize=2, # Not array chunks. This is for `multiprocessing`
126120
)
127121

0 commit comments

Comments
 (0)