4 | 4 | from __future__ import annotations
5 | 5 |
6 | 6 | import multiprocessing as mp |
| 7 | +from concurrent.futures import ProcessPoolExecutor |
7 | 8 | from itertools import repeat |
8 | 9 |
9 | 10 | import numpy as np |
19 | 20 |
20 | 21 | from mdio.core import Grid |
21 | 22 | from mdio.core.indexing import ChunkIterator |
22 | | -from mdio.segy._workers import trace_worker_map |
| 23 | +from mdio.segy._workers import trace_worker |
23 | 24 | from mdio.segy.byte_utils import ByteOrder |
24 | 25 | from mdio.segy.byte_utils import Dtype |
25 | 26 | from mdio.segy.creation import concat_files |
@@ -132,35 +133,27 @@ def to_zarr( |
132 | 133 | chunker = ChunkIterator(trace_array, chunk_samples=False) |
133 | 134 | num_chunks = len(chunker) |
134 | 135 |
135 | | - # Setting all multiprocessing parameters. |
136 | | - parallel_inputs = zip( # noqa: B905 |
137 | | - repeat(segy_path), |
138 | | - repeat(trace_array), |
139 | | - repeat(header_array), |
140 | | - repeat(grid), |
141 | | - chunker, |
142 | | - repeat(segy_endian), |
143 | | - ) |
144 | | - |
145 | | - # This is for Unix async writes to s3fs/fsspec, when using |
146 | | - # multiprocessing. By default, Linux uses the 'fork' method. |
147 | | - # 'spawn' is a little slower to spool up processes, but 'fork' |
148 | | - # doesn't work. If you don't use this, processes get deadlocked |
149 | | - # on cloud stores. 'spawn' is default in Windows. |
| 136 | + # For Unix async writes with s3fs/fsspec & multiprocessing, |
| 137 | + # use 'spawn' instead of default 'fork' to avoid deadlocks |
| 138 | + # on cloud stores. Slower but necessary. Default on Windows. |
| 139 | + num_workers = min(num_chunks, NUM_CORES) |
150 | 140 | context = mp.get_context("spawn") |
| 141 | + executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context) |
151 | 142 |
152 | | - # This is the chunksize for multiprocessing. Not to be confused |
153 | | - # with Zarr chunksize. |
154 | | - num_workers = min(num_chunks, NUM_CORES) |
| 143 | + # Chunksize here is for multiprocessing, not Zarr chunksize. |
155 | 144 | pool_chunksize, extra = divmod(num_chunks, num_workers * 4) |
156 | 145 | pool_chunksize += 1 if extra else pool_chunksize |
157 | 146 |
158 | 147 | tqdm_kw = dict(unit="block", dynamic_ncols=True) |
159 | | - with context.Pool(num_workers) as pool: |
160 | | - # pool.imap is lazy |
161 | | - lazy_work = pool.imap( |
162 | | - func=trace_worker_map, |
163 | | - iterable=parallel_inputs, |
| 148 | + with executor: |
| 149 | + lazy_work = executor.map( |
| 150 | + trace_worker, # fn |
| 151 | + repeat(segy_path), |
| 152 | + repeat(trace_array), |
| 153 | + repeat(header_array), |
| 154 | + repeat(grid), |
| 155 | + chunker, |
| 156 | + repeat(segy_endian), |
164 | 157 | chunksize=pool_chunksize, |
165 | 158 | ) |
166 | 159 |
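Context for the switch from trace_worker_map to trace_worker: multiprocessing.Pool.imap accepts a single iterable, which is why the arguments were pre-zipped into tuples and routed through a wrapper (presumably a thin tuple-unpacking shim around trace_worker). ProcessPoolExecutor.map instead takes the callable plus one iterable per positional argument and unpacks them itself, so the worker can be submitted directly. A minimal sketch of the two calling conventions, using a toy function rather than the actual mdio worker:

    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor
    from itertools import repeat


    def worker(base, offset):
        """Toy stand-in for trace_worker."""
        return base + offset


    def worker_map(args):
        """Tuple-unpacking shim, only needed for Pool.imap."""
        return worker(*args)


    if __name__ == "__main__":
        # Old style: pre-zip the arguments, map the wrapper over tuples.
        with mp.get_context("spawn").Pool(2) as pool:
            old = list(pool.imap(worker_map, zip(repeat(10), range(4))))

        # New style: pass the worker directly, one iterable per argument.
        with ProcessPoolExecutor(max_workers=2) as executor:
            new = list(executor.map(worker, repeat(10), range(4)))

        assert old == new == [10, 11, 12, 13]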
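On the pool chunksize: the chunksize argument to executor.map batches multiple work items into a single inter-process dispatch to amortize pickling overhead; it is unrelated to the Zarr chunk shape. A worked example with hypothetical values (100 grid chunks on an 8-core machine):

    num_chunks, num_workers = 100, 8                              # hypothetical values
    pool_chunksize, extra = divmod(num_chunks, num_workers * 4)   # (3, 4)
    pool_chunksize += 1 if extra else pool_chunksize
    assert pool_chunksize == 4

Each dispatch then carries four chunk regions, so the 100 chunks are spread over 25 tasks.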
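One behavioral difference worth noting: executor.map schedules the work up front, but results are retrieved lazily, so they must be drained inside the "with executor:" block, whose exit calls shutdown(wait=True) (whereas exiting "with context.Pool(...)" terminates the pool). A minimal sketch of that consumption pattern, assuming the results are fed to a tqdm progress bar as the tqdm_kw line suggests; the callable and counts are toy values, not the real worker:

    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor

    from tqdm.auto import tqdm


    def square(value):
        """Toy stand-in for trace_worker."""
        return value * value


    if __name__ == "__main__":
        context = mp.get_context("spawn")  # same spawn-over-fork rationale as in the diff
        tqdm_kw = dict(unit="block", dynamic_ncols=True)

        with ProcessPoolExecutor(max_workers=2, mp_context=context) as executor:
            lazy_work = executor.map(square, range(8), chunksize=2)
            # Drain the results before the with-block exits and joins the workers.
            for _ in tqdm(lazy_work, total=8, **tqdm_kw):
                pass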