Commit 81662c7

Working example
1 parent b727a2b commit 81662c7

File tree: 3 files changed, +243 −57 lines
Lines changed: 227 additions & 0 deletions (new file; path not shown in this view, presumably src/mdio/segy/_disaster_recovery_wrapper.py given the import added in src/mdio/segy/_workers.py below)
@@ -0,0 +1,227 @@
"""Consumer-side utility to get both raw and transformed header data with single filesystem read."""

from __future__ import annotations

import numpy as np
from typing import TYPE_CHECKING
from segy.transforms import ByteSwapTransform
from segy.transforms import IbmFloatTransform

if TYPE_CHECKING:
    from segy.file import SegyFile
    from segy.indexing import HeaderIndexer
    from segy.transforms import Transform, TransformPipeline, ByteSwapTransform, IbmFloatTransform
    from numpy.typing import NDArray


def debug_compare_raw_vs_processed(segy_file, trace_index=0):
    """Debug function to compare raw filesystem data vs processed data."""
    from segy.indexing import HeaderIndexer

    # Create a fresh indexer to get raw data
    indexer = HeaderIndexer(
        segy_file.fs,
        segy_file.url,
        segy_file.spec.trace,
        segy_file.num_traces,
        transform_pipeline=None  # No transforms = raw data
    )

    # Get raw data directly from filesystem
    raw_data = indexer[trace_index]

    # Get processed data with transforms
    processed_data = segy_file.header[trace_index]

    print("=== Raw vs Processed Comparison ===")
    print(f"Raw data shape: {raw_data.shape}")
    print(f"Processed data shape: {processed_data.shape}")

    if hasattr(raw_data, 'dtype') and raw_data.dtype.names:
        if 'inline_number' in raw_data.dtype.names:
            print(f"Raw inline_number: {raw_data['inline_number']}")
            print(f"Raw inline_number (hex): {raw_data['inline_number']:08x}")
            print(f"Processed inline_number: {processed_data['inline_number']}")
            print(f"Processed inline_number (hex): {processed_data['inline_number']:08x}")
            print(f"Are they equal? {raw_data['inline_number'] == processed_data['inline_number']}")

    return raw_data, processed_data


class HeaderRawTransformedAccessor:
    """Utility class to access both raw and transformed header data with single filesystem read.

    This class works as a consumer of SegyFile objects without modifying the package.
    It achieves the goal by:
    1. Reading raw data from filesystem once
    2. Applying transforms to get transformed data
    3. Keeping both versions available

    The transforms used in SEG-Y processing are reversible:
    - ByteSwapTransform: Self-inverse (swapping twice returns to original)
    - IbmFloatTransform: Can be reversed by swapping direction
    """

    def __init__(self, segy_file: SegyFile):
        """Initialize with a SegyFile instance.

        Args:
            segy_file: The SegyFile instance to work with
        """
        self.segy_file = segy_file
        self.header_indexer = segy_file.header
        self.transform_pipeline = self.header_indexer.transform_pipeline

        # Debug: Print transform pipeline information
        import sys
        print(f"Debug: System endianness: {sys.byteorder}")
        print(f"Debug: File endianness: {self.segy_file.spec.endianness}")
        print(f"Debug: Transform pipeline has {len(self.transform_pipeline.transforms)} transforms:")
        for i, transform in enumerate(self.transform_pipeline.transforms):
            print(f"  Transform {i}: {type(transform).__name__}")
            if hasattr(transform, 'target_order'):
                print(f"    Target order: {transform.target_order}")
            if hasattr(transform, 'direction'):
                print(f"    Direction: {transform.direction}")
            if hasattr(transform, 'keys'):
                print(f"    Keys: {transform.keys}")

    def get_raw_and_transformed(
        self, indices: int | list[int] | np.ndarray | slice
    ) -> tuple[NDArray, NDArray]:
        """Get both raw and transformed header data with single filesystem read.

        Args:
            indices: Which headers to retrieve (int, list, ndarray, or slice)

        Returns:
            Tuple of (raw_headers, transformed_headers)
        """
        # Get the transformed data using the normal API
        # This reads from filesystem and applies transforms
        transformed_data = self.header_indexer[indices]

        print(f"Debug: Transformed data shape: {transformed_data.shape}")
        if hasattr(transformed_data, 'dtype') and transformed_data.dtype.names:
            print(f"Debug: Transformed data dtype names: {transformed_data.dtype.names[:5]}...")  # First 5 fields
            if 'inline_number' in transformed_data.dtype.names:
                print(f"Debug: First transformed inline_number: {transformed_data['inline_number'][0]}")
                print(f"Debug: First transformed inline_number (hex): {transformed_data['inline_number'][0]:08x}")

        # Now reverse the transforms to get back to raw data
        raw_data = self._reverse_transforms(transformed_data)

        print(f"Debug: Raw data shape: {raw_data.shape}")
        if hasattr(raw_data, 'dtype') and raw_data.dtype.names:
            if 'inline_number' in raw_data.dtype.names:
                print(f"Debug: First raw inline_number: {raw_data['inline_number'][0]}")
                print(f"Debug: First raw inline_number (hex): {raw_data['inline_number'][0]:08x}")

        return raw_data, transformed_data

    def _reverse_transforms(self, transformed_data: NDArray) -> NDArray:
        """Reverse the transform pipeline to get raw data from transformed data.

        Args:
            transformed_data: Data that has been processed through the transform pipeline

        Returns:
            Raw data equivalent to what was read directly from filesystem
        """
        # Start with the transformed data
        raw_data = transformed_data.copy() if hasattr(transformed_data, 'copy') else transformed_data

        print(f"Debug: Starting reversal with {len(self.transform_pipeline.transforms)} transforms")

        # Apply transforms in reverse order with reversed operations
        for i, transform in enumerate(reversed(self.transform_pipeline.transforms)):
            print(f"Debug: Reversing transform {len(self.transform_pipeline.transforms)-1-i}: {type(transform).__name__}")
            if 'inline_number' in raw_data.dtype.names:
                print(f"Debug: Before reversal - inline_number: {raw_data['inline_number'][0]:08x}")
            raw_data = self._reverse_single_transform(raw_data, transform)
            if 'inline_number' in raw_data.dtype.names:
                print(f"Debug: After reversal - inline_number: {raw_data['inline_number'][0]:08x}")

        return raw_data

    def _reverse_single_transform(self, data: NDArray, transform: Transform) -> NDArray:
        """Reverse a single transform operation.

        Args:
            data: The data to reverse transform
            transform: The transform to reverse

        Returns:
            Data with the transform reversed
        """
        # Import here to avoid circular imports
        from segy.transforms import get_endianness
        from segy.schema import Endianness

        if isinstance(transform, ByteSwapTransform):
            # For byte swap, we need to reverse the endianness conversion
            # If the transform was converting to little-endian, we need to convert back to big-endian
            print(f"Debug: Reversing byte swap (target was: {transform.target_order})")

            # Get current data endianness
            current_endianness = get_endianness(data)
            print(f"Debug: Current data endianness: {current_endianness}")

            # If transform was converting TO little-endian, we need to convert TO big-endian
            if transform.target_order == Endianness.LITTLE:
                reverse_target = Endianness.BIG
            else:
                reverse_target = Endianness.LITTLE

            print(f"Debug: Reversing to target: {reverse_target}")
            reverse_transform = ByteSwapTransform(reverse_target)
            result = reverse_transform.apply(data)

            if 'inline_number' in data.dtype.names:
                print(f"Debug: Byte swap reversal - before: {data['inline_number'][0]:08x}, after: {result['inline_number'][0]:08x}")
            return result

        elif isinstance(transform, IbmFloatTransform):
            # Reverse IBM float conversion by swapping direction
            reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
            print(f"Debug: Applying IBM float reversal (direction: {transform.direction} -> {reverse_direction})")
            reverse_transform = IbmFloatTransform(reverse_direction, transform.keys)
            return reverse_transform.apply(data)

        else:
            # For unknown transforms, return data unchanged
            # This maintains compatibility if new transforms are added
            print(f"Warning: Unknown transform type {type(transform).__name__}, cannot reverse")
            return data


def get_header_raw_and_transformed(
    segy_file: SegyFile,
    indices: int | list[int] | np.ndarray | slice
) -> tuple[NDArray, NDArray]:
    """Convenience function to get both raw and transformed header data.

    This is a drop-in replacement that provides the functionality you requested
    without modifying the segy package.

    Args:
        segy_file: The SegyFile instance
        indices: Which headers to retrieve

    Returns:
        Tuple of (raw_headers, transformed_headers)

    Example:
        from header_raw_transformed_accessor import get_header_raw_and_transformed

        # Single header
        raw_hdr, transformed_hdr = get_header_raw_and_transformed(segy_file, 0)

        # Multiple headers
        raw_hdrs, transformed_hdrs = get_header_raw_and_transformed(segy_file, [0, 1, 2])

        # Slice of headers
        raw_hdrs, transformed_hdrs = get_header_raw_and_transformed(segy_file, slice(0, 10))
    """
    accessor = HeaderRawTransformedAccessor(segy_file)
    return accessor.get_raw_and_transformed(indices)
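
The reversal strategy in this file assumes the pipeline's transforms are invertible, as the class docstring notes. As a quick standalone illustration of the byte-swap case (a numpy-only sketch with a hypothetical header field, not part of the commit), swapping to the opposite byte order and then swapping back reproduces the original bytes exactly:

import numpy as np

# A big-endian 32-bit value, as it might sit in a SEG-Y header field (hypothetical).
raw = np.array([1234], dtype=">i4")

# "Forward" step: convert to little-endian, as a ByteSwapTransform-style step would.
transformed = raw.byteswap().view(raw.dtype.newbyteorder("<"))

# "Reverse" step: swap again; the underlying bytes match the original exactly.
recovered = transformed.byteswap().view(transformed.dtype.newbyteorder(">"))
assert recovered.tobytes() == raw.tobytes()
assert int(recovered[0]) == 1234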

src/mdio/segy/_workers.py

Lines changed: 8 additions & 29 deletions
@@ -13,6 +13,7 @@
 
 from mdio.api.io import to_mdio
 from mdio.builder.schemas.dtype import ScalarType
+from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
 
 if TYPE_CHECKING:
     from segy.arrays import HeaderArray
@@ -81,7 +82,7 @@ def header_scan_worker(
 
     return cast("HeaderArray", trace_header)
 
-
+@profile
 def trace_worker(  # noqa: PLR0913
     segy_kw: SegyFileArguments,
     output_path: UPath,
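
The bare @profile decorator added above follows the convention of profilers such as line_profiler's kernprof (and memory_profiler), which inject a profile object into builtins at run time, so no import is needed during a profiling run; outside such a run the name would be undefined. A common guard for that case (a sketch under that assumption, not part of this commit):

try:
    profile  # Injected into builtins by the profiler (e.g. kernprof -l) during profiling runs.
except NameError:
    def profile(func):
        # No-op fallback so the module still imports when not run under a profiler.
        return func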
@@ -134,12 +135,13 @@ def trace_worker(  # noqa: PLR0913
         worker_variables.append(raw_header_key)
 
     ds_to_write = dataset[worker_variables]
+    raw_headers, transformed_headers = get_header_raw_and_transformed(segy_file, live_trace_indexes)
 
     if header_key in worker_variables:
         # Create temporary array for headers with the correct shape
         # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
         tmp_headers = np.zeros_like(dataset[header_key])
-        tmp_headers[not_null] = traces.header
+        tmp_headers[not_null] = transformed_headers
         # Create a new Variable object to avoid copying the temporary array
         # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
         # but Xarray appears to be copying memory instead of doing direct assignment.
@@ -150,40 +152,17 @@
             attrs=ds_to_write[header_key].attrs,
             encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
+        del transformed_headers  # Manage memory
     if raw_header_key in worker_variables:
         tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
-
-        # Get the indices where we need to place results
-        live_mask = not_null
-        live_positions = np.where(live_mask.ravel())[0]
-
-        if len(live_positions) > 0:
-            # Calculate byte ranges for headers
-            header_size = 240
-            trace_offset = segy_file.spec.trace.offset
-            trace_itemsize = segy_file.spec.trace.itemsize
-
-            starts = []
-            ends = []
-            for global_trace_idx in live_trace_indexes:
-                header_start = trace_offset + global_trace_idx * trace_itemsize
-                header_end = header_start + header_size
-                starts.append(header_start)
-                ends.append(header_end)
-
-            # Capture raw bytes
-            raw_header_bytes = merge_cat_file(segy_file.fs, segy_file.url, starts, ends)
-
-            # Convert and place results
-            raw_headers_array = np.frombuffer(bytes(raw_header_bytes), dtype="|V240")
-            tmp_raw_headers.ravel()[live_positions] = raw_headers_array
-
+        tmp_raw_headers[not_null] = raw_headers.view("|V240")
         ds_to_write[raw_header_key] = Variable(
             ds_to_write[raw_header_key].dims,
             tmp_raw_headers,
             attrs=ds_to_write[raw_header_key].attrs,
-            encoding=ds_to_write[raw_header_key].encoding,
+            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
+        del raw_headers  # Manage memory
 
     data_variable = ds_to_write[data_variable_name]
     fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
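
The new raw-header assignment takes the raw headers recovered by get_header_raw_and_transformed and relies on numpy reinterpreting each structured record as one opaque 240-byte value (the SEG-Y trace header size), replacing the separate merge_cat_file byte-range read removed above. A standalone sketch of that reinterpretation (hypothetical dtype and field names, not the mdio code):

import numpy as np

# Hypothetical structured dtype padded out to the 240-byte SEG-Y trace header size.
hdr_dtype = np.dtype({"names": ["inline_number"], "formats": [">i4"], "itemsize": 240})
headers = np.zeros(3, dtype=hdr_dtype)
headers["inline_number"] = [10, 11, 12]

# Reinterpret each 240-byte record as a raw void blob; no bytes are copied or reordered.
raw_view = headers.view("|V240")
print(raw_view.dtype, raw_view.shape)  # -> V240 (3,)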

src/mdio/segy/blocked_io.py

Lines changed: 8 additions & 28 deletions
@@ -2,10 +2,7 @@
 
 from __future__ import annotations
 
-import multiprocessing as mp
 import os
-from concurrent.futures import ProcessPoolExecutor
-from concurrent.futures import as_completed
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -80,37 +77,20 @@ def to_zarr(  # noqa: PLR0913, PLR0915
     chunk_iter = ChunkIterator(shape=data.shape, chunks=worker_chunks, dim_names=data.dims)
     num_chunks = chunk_iter.num_chunks
 
-    # For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
-    # 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
-    num_cpus = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
-    num_workers = min(num_chunks, num_cpus)
-    context = mp.get_context("spawn")
-    executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
-
     segy_kw = {
         "url": segy_file.fs.unstrip_protocol(segy_file.url),
         "spec": segy_file.spec,
         "settings": segy_file.settings,
     }
-    with executor:
-        futures = []
-        common_args = (segy_kw, output_path, data_variable_name)
-        for region in chunk_iter:
-            subset_args = (region, grid_map, dataset.isel(region))
-            future = executor.submit(trace_worker, *common_args, *subset_args)
-            futures.append(future)
-
-        iterable = tqdm(
-            as_completed(futures),
-            total=num_chunks,
-            unit="block",
-            desc="Ingesting traces",
-        )
 
-        for future in iterable:
-            result = future.result()
-            if result is not None:
-                _update_stats(final_stats, result)
+    common_args = (segy_kw, output_path, data_variable_name)
+
+    # Execute trace_worker serially for profiling
+    for region in tqdm(chunk_iter, total=num_chunks, unit="block", desc="Ingesting traces"):
+        subset_args = (region, grid_map, dataset.isel(region))
+        result = trace_worker(*common_args, *subset_args)
+        if result is not None:
+            _update_stats(final_stats, result)
 
     # Xarray doesn't directly support incremental attribute updates when appending to an existing Zarr store.
     # HACK: We will update the array attribute using zarr's API directly.
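
The replacement loop drops the spawn-context process pool and calls trace_worker in-process, presumably so the @profile instrumentation added in _workers.py can attribute time line by line within a single process. Stripped of the mdio specifics, the serial pattern being used is simply (a generic sketch with a hypothetical worker, not the package code):

from tqdm import tqdm

def fake_worker(block_id):
    # Hypothetical stand-in for trace_worker: returns stats or None for empty blocks.
    return {"blocks": 1} if block_id % 2 == 0 else None

results = []
for block_id in tqdm(range(8), total=8, unit="block", desc="Ingesting traces"):
    result = fake_worker(block_id)
    if result is not None:
        results.append(result)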
