
Commit 42f0ab6

Runtime regression fix
1 parent 2e18d97 commit 42f0ab6

File tree: 3 files changed, +38 −83 lines

  src/mdio/segy/_disaster_recovery_wrapper.py
  src/mdio/segy/_workers.py
  src/mdio/segy/blocked_io.py

src/mdio/segy/_disaster_recovery_wrapper.py

Lines changed: 4 additions & 70 deletions
@@ -13,50 +13,6 @@
 from segy.transforms import Transform, TransformPipeline, ByteSwapTransform, IbmFloatTransform
 from numpy.typing import NDArray
 
-
-class HeaderRawTransformedAccessor:
-    """Utility class to access both raw and transformed header data with single filesystem read.
-
-    This class works as a consumer of SegyFile objects without modifying the package.
-    It achieves the goal by:
-    1. Reading raw data from filesystem once
-    2. Applying transforms to get transformed data
-    3. Keeping both versions available
-
-    The transforms used in SEG-Y processing are reversible:
-    - ByteSwapTransform: Self-inverse (swapping twice returns to original)
-    - IbmFloatTransform: Can be reversed by swapping direction
-    """
-
-    def __init__(self, segy_file: SegyFile):
-        """Initialize with a SegyFile instance.
-
-        Args:
-            segy_file: The SegyFile instance to work with
-        """
-        self.segy_file = segy_file
-        self.transform_pipeline = self.segy_file.header.transform_pipeline
-
-    def _reverse_transforms(self, transformed_data: NDArray) -> NDArray:
-        """Reverse the transform pipeline to get raw data from transformed data.
-
-        Args:
-            transformed_data: Data that has been processed through the transform pipeline
-
-        Returns:
-            Raw data equivalent to what was read directly from filesystem
-        """
-        # Start with the transformed data
-        raw_data = transformed_data.copy() if hasattr(transformed_data, 'copy') else transformed_data
-
-
-        # Apply transforms in reverse order with reversed operations
-        for i, transform in enumerate(reversed(self.transform_pipeline.transforms)):
-            raw_data = _reverse_single_transform(raw_data, transform)
-
-        return raw_data
-
-@profile
 def _reverse_single_transform(data: NDArray, transform: Transform) -> NDArray:
     """Reverse a single transform operation.
 
@@ -98,11 +54,10 @@ def _reverse_single_transform(data: NDArray, transform: Transform) -> NDArray:
     # This maintains compatibility if new transforms are added
     return data
 
-
 def get_header_raw_and_transformed(
     segy_file: SegyFile,
     indices: int | list[int] | np.ndarray | slice
-) -> tuple[NDArray, NDArray]:
+) -> tuple[NDArray, NDArray, NDArray]:
     """Convenience function to get both raw and transformed header data.
 
     This is a drop-in replacement that provides the functionality you requested
@@ -127,38 +82,17 @@ def get_header_raw_and_transformed(
         # Slice of headers
         raw_hdrs, transformed_hdrs = get_header_raw_and_transformed(segy_file, slice(0, 10))
     """
-    return _get_header_raw_optimized(segy_file, indices)
-
-@profile
-def _get_header_raw_optimized(
-    segy_file: SegyFile,
-    indices: int | list[int] | np.ndarray | slice
-) -> tuple[NDArray, NDArray]:
-    """Ultra-optimized function that eliminates double disk reads entirely.
 
-    This function:
-    1. Gets transformed headers using the normal API (single disk read)
-    2. Reverses the transforms on the already-loaded data (no second disk read)
-    3. Returns both raw and transformed headers
+    traces = segy_file.trace[indices]
 
-    Args:
-        segy_file: The SegyFile instance
-        indices: Which headers to retrieve
-
-    Returns:
-        Tuple of (raw_headers, transformed_headers) where transformed_headers
-        is the same as what segy_file.header[indices] would return
-    """
-    # Get transformed headers using the normal API (single disk read)
-    transformed_headers = segy_file.header[indices]
+    transformed_headers = traces.header
 
     # Reverse the transforms on the already-loaded transformed data
     # This eliminates the second disk read entirely!
    raw_headers = _reverse_transforms(transformed_headers, segy_file.header.transform_pipeline)
 
-    return raw_headers, transformed_headers
+    return raw_headers, transformed_headers, traces
 
-@profile
 def _reverse_transforms(transformed_data: NDArray, transform_pipeline) -> NDArray:
     """Reverse the transform pipeline to get raw data from transformed data.
 
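Net effect of this file's change: `get_header_raw_and_transformed` now reads the traces once and returns three objects (raw headers, transformed headers, and the traces themselves) instead of two. A minimal usage sketch follows; the direct `SegyFile("sample.segy")` construction and the file path are placeholder assumptions for illustration, not part of this commit, and in MDIO the `SegyFile` is built from `segy_kw` inside the workers.

```python
from segy import SegyFile

from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed

# Placeholder file; spec/settings omitted for brevity.
segy_file = SegyFile("sample.segy")

# One filesystem read: traces are fetched once, transformed headers come from
# traces.header, and raw headers are recovered by reversing the transform pipeline.
raw_hdrs, transformed_hdrs, traces = get_header_raw_and_transformed(segy_file, slice(0, 10))

print(raw_hdrs.dtype)          # pre-transform header records
print(transformed_hdrs.dtype)  # same as segy_file.header[0:10] would return
```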
src/mdio/segy/_workers.py

Lines changed: 5 additions & 4 deletions
@@ -121,7 +121,8 @@ def trace_worker( # noqa: PLR0913
     zarr_config.set({"threading.max_workers": 1})
 
     live_trace_indexes = local_grid_map[not_null].tolist()
-    traces = segy_file.trace[live_trace_indexes]
+    # traces = segy_file.trace[live_trace_indexes]
+    raw_headers, transformed_headers, traces = get_header_raw_and_transformed(segy_file, live_trace_indexes)
 
     header_key = "headers"
     raw_header_key = "raw_headers"
@@ -135,7 +136,7 @@ def trace_worker( # noqa: PLR0913
         worker_variables.append(raw_header_key)
 
     ds_to_write = dataset[worker_variables]
-    raw_headers, transformed_headers = get_header_raw_and_transformed(segy_file, live_trace_indexes)
+    # raw_headers, transformed_headers = get_header_raw_and_transformed(segy_file, live_trace_indexes)
 
     if header_key in worker_variables:
         # Create temporary array for headers with the correct shape
@@ -153,7 +154,7 @@ def trace_worker( # noqa: PLR0913
             attrs=ds_to_write[header_key].attrs,
             encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
-        # del transformed_headers  # Manage memory
+        del transformed_headers  # Manage memory
     if raw_header_key in worker_variables:
         tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
         tmp_raw_headers[not_null] = raw_headers.view("|V240")
@@ -163,8 +164,8 @@ def trace_worker( # noqa: PLR0913
             attrs=ds_to_write[raw_header_key].attrs,
             encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
-        del raw_headers  # Manage memory
 
+        del raw_headers  # Manage memory
     data_variable = ds_to_write[data_variable_name]
     fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
     tmp_samples = np.full_like(data_variable, fill_value=fill_value)
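The re-enabled `del` statements and the `|V240` view are the memory-management core of this hunk: live-trace headers are copied into a full-grid temporary array as opaque 240-byte records, then the per-trace arrays are released. Below is a standalone sketch of that staging pattern with made-up shapes; the real mask (`not_null`) and grid come from the worker's region.

```python
import numpy as np

# Hypothetical 10-trace region with 3 live traces; in trace_worker this mask is not_null.
not_null = np.array([True, False, True, True, False, False, False, False, False, False])

# Stand-in for raw_headers: one opaque 240-byte record per live trace.
raw_headers = np.zeros(int(not_null.sum()), dtype="|V240")

# Full-grid temporary array, zero-filled where there is no trace.
tmp_raw_headers = np.zeros(not_null.shape, dtype="|V240")
tmp_raw_headers[not_null] = raw_headers.view("|V240")

# Release the per-trace array as soon as it has been copied, mirroring the diff.
del raw_headers
```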

src/mdio/segy/blocked_io.py

Lines changed: 29 additions & 9 deletions
@@ -2,7 +2,10 @@
 
 from __future__ import annotations
 
+import multiprocessing as mp
 import os
+from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import as_completed
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -77,20 +80,37 @@ def to_zarr( # noqa: PLR0913, PLR0915
     chunk_iter = ChunkIterator(shape=data.shape, chunks=worker_chunks, dim_names=data.dims)
     num_chunks = chunk_iter.num_chunks
 
+    # For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
+    # 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
+    num_cpus = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
+    num_workers = min(num_chunks, num_cpus)
+    context = mp.get_context("spawn")
+    executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
+
     segy_kw = {
         "url": segy_file.fs.unstrip_protocol(segy_file.url),
         "spec": segy_file.spec,
         "settings": segy_file.settings,
     }
+    with executor:
+        futures = []
+        common_args = (segy_kw, output_path, data_variable_name)
+        for region in chunk_iter:
+            subset_args = (region, grid_map, dataset.isel(region))
+            future = executor.submit(trace_worker, *common_args, *subset_args)
+            futures.append(future)
+
+        iterable = tqdm(
+            as_completed(futures),
+            total=num_chunks,
+            unit="block",
+            desc="Ingesting traces",
+        )
 
-    common_args = (segy_kw, output_path, data_variable_name)
-
-    # Execute trace_worker serially for profiling
-    for region in tqdm(chunk_iter, total=num_chunks, unit="block", desc="Ingesting traces"):
-        subset_args = (region, grid_map, dataset.isel(region))
-        result = trace_worker(*common_args, *subset_args)
-        if result is not None:
-            _update_stats(final_stats, result)
+        for future in iterable:
+            result = future.result()
+            if result is not None:
+                _update_stats(final_stats, result)
 
     # Xarray doesn't directly support incremental attribute updates when appending to an existing Zarr store.
     # HACK: We will update the array attribute using zarr's API directly.
@@ -260,4 +280,4 @@ def to_segy(
 
     non_consecutive_axes -= 1
 
-    return block_io_records
+    return block_io_records
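This file swaps the serial profiling loop back to a process pool: chunks are submitted with `executor.submit`, progress is tracked over `as_completed` with `tqdm`, and the "spawn" start method sidesteps fork-related deadlocks with fsspec/s3fs. A self-contained sketch of that pattern with a dummy worker; the names here are illustrative, not the MDIO API.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed

from tqdm import tqdm


def work(i: int) -> int:
    # Placeholder for trace_worker: square the input.
    return i * i


def main() -> None:
    tasks = range(16)
    # 'spawn' avoids fork-related deadlocks with async cloud filesystems, per the diff comment.
    context = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=4, mp_context=context) as executor:
        futures = [executor.submit(work, i) for i in tasks]
        results = []
        # Consume results as they complete, showing progress per finished block.
        for future in tqdm(as_completed(futures), total=len(futures), unit="block"):
            results.append(future.result())
    print(sum(results))


if __name__ == "__main__":
    main()  # the guard is required when using the 'spawn' start method
```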
