Skip to content

Commit bbabc08

Browse files
Refactor initial file information handling while avoiding asyncio issues. (#701)
* Fix bug 613 on top of latest main
* Exclude TestBug613HandlesNotClosed from coverage
* Space is added by pre-commit
* Exclude from code coverage: test_teapot_import_cloud_to_cloud
* Remove tests as per the PR review
* PR review: move _get_coordinate_scalar()
* PR Review fixes
* refactor: replace individual SEG-Y file arguments with SegyFileInfo object for cleaner function signatures and consistent data access
* refactor: rename parameter in _get_horizontal_coordinate_unit for consistency with SegyFileInfo usage
* refactor: rename parameter in _add_segy_file_headers for consistency with SegyFileInfo usage
* refactor: rename segy_kw to segy_file_kwargs for consistency with naming conventions
* refactor: rename segy_kw to segy_file_kwargs across functions for improved clarity and consistency
* refactor: reorder imports in scalar.py for adherence to style guidelines and future type checking compatibility
* refactor: simplify info_worker by removing trace_indices handling and unused imports
* refactor: update type hints in _workers.py for improved readability and compliance with TYPE_CHECKING

---------

Co-authored-by: Altay Sansal <[email protected]>
1 parent a664dbd commit bbabc08

File tree

6 files changed

+151
-94
lines changed

6 files changed

+151
-94
lines changed

src/mdio/converters/segy.py

Lines changed: 55 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import base64
66
import logging
7+
import multiprocessing as mp
78
import os
8-
from dataclasses import dataclass
9+
from concurrent.futures import ProcessPoolExecutor
910
from typing import TYPE_CHECKING
1011

1112
import numpy as np
1213
import zarr
13-
from segy import SegyFile
1414
from segy.config import SegyFileSettings
1515
from segy.config import SegyHeaderOverrides
1616
from segy.standards.codes import MeasurementSystem as SegyMeasurementSystem
@@ -37,9 +37,10 @@
3737
from mdio.core.utils_write import MAX_SIZE_LIVE_MASK
3838
from mdio.core.utils_write import get_constrained_chunksize
3939
from mdio.segy import blocked_io
40+
from mdio.segy._workers import SegyFileInfo
41+
from mdio.segy._workers import info_worker
4042
from mdio.segy.scalar import SCALE_COORDINATE_KEYS
4143
from mdio.segy.scalar import _apply_coordinate_scalar
42-
from mdio.segy.scalar import _get_coordinate_scalar
4344
from mdio.segy.utilities import get_grid_plan
4445

4546
if TYPE_CHECKING:
@@ -54,6 +55,7 @@
5455
from mdio.builder.schemas import Dataset
5556
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
5657
from mdio.core.dimension import Dimension
58+
from mdio.segy._workers import SegyFileArguments
5759

5860
logger = logging.getLogger(__name__)
5961

@@ -135,37 +137,9 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None:
135137
raise GridTraceSparsityError(grid.shape, num_traces, msg)
136138

137139

138-
@dataclass
139-
class SegyFileHeaderDump:
140-
"""Segy metadata information."""
141-
142-
text_header: str
143-
binary_header_dict: dict
144-
raw_binary_headers: bytes
145-
146-
147-
def _get_segy_file_header_dump(segy_file: SegyFile) -> SegyFileHeaderDump:
148-
"""Reads information from a SEG-Y file."""
149-
text_header = segy_file.text_header
150-
151-
raw_binary_headers: bytes = segy_file.fs.read_block(
152-
fn=segy_file.url,
153-
offset=segy_file.spec.binary_header.offset,
154-
length=segy_file.spec.binary_header.itemsize,
155-
)
156-
157-
# We read here twice, but it's ok for now. Only 400-bytes.
158-
binary_header_dict = segy_file.binary_header.to_dict()
159-
160-
return SegyFileHeaderDump(
161-
text_header=text_header,
162-
binary_header_dict=binary_header_dict,
163-
raw_binary_headers=raw_binary_headers,
164-
)
165-
166-
167140
def _scan_for_headers(
168-
segy_file: SegyFile,
141+
segy_file_kwargs: SegyFileArguments,
142+
segy_file_info: SegyFileInfo,
169143
template: AbstractDatasetTemplate,
170144
grid_overrides: dict[str, Any] | None = None,
171145
) -> tuple[list[Dimension], SegyHeaderArray]:
@@ -176,7 +150,8 @@ def _scan_for_headers(
176150
"""
177151
full_chunk_size = template.full_chunk_size
178152
segy_dimensions, chunk_size, segy_headers = get_grid_plan(
179-
segy_file=segy_file,
153+
segy_file_kwargs=segy_file_kwargs,
154+
segy_file_info=segy_file_info,
180155
return_headers=True,
181156
template=template,
182157
chunksize=full_chunk_size,
@@ -192,12 +167,29 @@ def _scan_for_headers(
192167
return segy_dimensions, segy_headers
193168

194169

195-
def _build_and_check_grid(segy_dimensions: list[Dimension], num_traces: int, segy_headers: SegyHeaderArray) -> Grid:
170+
def _read_segy_file_info(segy_file_kwargs: SegyFileArguments) -> SegyFileInfo:
    """Collect SEG-Y file information using a short-lived worker process.

    Submits ``info_worker`` to a single-worker process pool created with the "spawn"
    start method and blocks until the result is available.

    This is an ugly workaround for Zarr issue 3487 'Explicitly using fsspec and zarr
    FsspecStore causes RuntimeError "Task attached to a different loop"'.

    Args:
        segy_file_kwargs: Arguments to open SegyFile instance.

    Returns:
        The SegyFileInfo returned by ``info_worker``.
    """
    # TODO (Dmitriy Repin): when Zarr issue 3487 is resolved, we can remove this workaround
    # https://github.com/zarr-developers/zarr-python/issues/3487
    spawn_context = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=1, mp_context=spawn_context) as pool:
        return pool.submit(info_worker, segy_file_kwargs).result()
181+
182+
183+
def _build_and_check_grid(
184+
segy_dimensions: list[Dimension],
185+
segy_file_info: SegyFileInfo,
186+
segy_headers: SegyHeaderArray,
187+
) -> Grid:
196188
"""Build and check the grid from the SEG-Y headers and dimensions.
197189
198190
Args:
199191
segy_dimensions: List of all SEG-Y dimensions to build grid from.
200-
num_traces: Number of traces in the SEG-Y file.
192+
segy_file_info: SegyFileInfo instance containing the SEG-Y file information.
201193
segy_headers: Headers read in from SEG-Y file for building the trace map.
202194
203195
Returns:
@@ -207,6 +199,7 @@ def _build_and_check_grid(segy_dimensions: list[Dimension], num_traces: int, seg
207199
GridTraceCountError: If number of traces in SEG-Y file does not match the parsed grid
208200
"""
209201
grid = Grid(dims=segy_dimensions)
202+
num_traces = segy_file_info.num_traces
210203
grid_density_qc(grid, num_traces)
211204
grid.build_map(segy_headers)
212205
# Check grid validity by comparing trace numbers
@@ -303,9 +296,9 @@ def populate_non_dim_coordinates(
303296
return dataset, drop_vars_delayed
304297

305298

306-
def _get_horizontal_coordinate_unit(segy_info: SegyFileHeaderDump) -> LengthUnitModel | None:
299+
def _get_horizontal_coordinate_unit(segy_file_info: SegyFileInfo) -> LengthUnitModel | None:
307300
"""Get the coordinate unit from the SEG-Y headers."""
308-
measurement_system_code = int(segy_info.binary_header_dict[MEASUREMENT_SYSTEM_KEY])
301+
measurement_system_code = int(segy_file_info.binary_header_dict[MEASUREMENT_SYSTEM_KEY])
309302

310303
if measurement_system_code not in (1, 2):
311304
logger.warning(
@@ -359,19 +352,19 @@ def _populate_coordinates(
359352
return dataset, drop_vars_delayed
360353

361354

362-
def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_header_dump: SegyFileHeaderDump) -> xr_Dataset:
355+
def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo) -> xr_Dataset:
363356
save_file_header = os.getenv("MDIO__IMPORT__SAVE_SEGY_FILE_HEADER", "") in ("1", "true", "yes", "on")
364357
if not save_file_header:
365358
return xr_dataset
366359

367360
expected_rows = 40
368361
expected_cols = 80
369362

370-
text_header_rows = segy_file_header_dump.text_header.splitlines()
363+
text_header_rows = segy_file_info.text_header.splitlines()
371364
text_header_cols_bad = [len(row) != expected_cols for row in text_header_rows]
372365

373366
if len(text_header_rows) != expected_rows:
374-
err = f"Invalid text header count: expected {expected_rows}, got {len(segy_file_header_dump.text_header)}"
367+
err = f"Invalid text header count: expected {expected_rows}, got {len(segy_file_info.text_header)}"
375368
raise ValueError(err)
376369

377370
if any(text_header_cols_bad):
@@ -381,12 +374,12 @@ def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_header_dump: SegyFi
381374
xr_dataset["segy_file_header"] = ((), "")
382375
xr_dataset["segy_file_header"].attrs.update(
383376
{
384-
"textHeader": segy_file_header_dump.text_header,
385-
"binaryHeader": segy_file_header_dump.binary_header_dict,
377+
"textHeader": segy_file_info.text_header,
378+
"binaryHeader": segy_file_info.binary_header_dict,
386379
}
387380
)
388381
if os.getenv("MDIO__IMPORT__RAW_HEADERS") in ("1", "true", "yes", "on"):
389-
raw_binary_base64 = base64.b64encode(segy_file_header_dump.raw_binary_headers).decode("ascii")
382+
raw_binary_base64 = base64.b64encode(segy_file_info.raw_binary_headers).decode("ascii")
390383
xr_dataset["segy_file_header"].attrs.update({"rawBinaryHeader": raw_binary_base64})
391384

392385
return xr_dataset
@@ -532,17 +525,21 @@ def segy_to_mdio( # noqa PLR0913
532525
raise FileExistsError(err)
533526

534527
segy_settings = SegyFileSettings(storage_options=input_path.storage_options)
535-
segy_file = SegyFile(
536-
url=input_path.as_posix(),
537-
spec=segy_spec,
538-
settings=segy_settings,
539-
header_overrides=segy_header_overrides,
528+
segy_file_kwargs: SegyFileArguments = {
529+
"url": input_path.as_posix(),
530+
"spec": segy_spec,
531+
"settings": segy_settings,
532+
"header_overrides": segy_header_overrides,
533+
}
534+
segy_file_info = _read_segy_file_info(segy_file_kwargs)
535+
536+
segy_dimensions, segy_headers = _scan_for_headers(
537+
segy_file_kwargs,
538+
segy_file_info,
539+
template=mdio_template,
540+
grid_overrides=grid_overrides,
540541
)
541-
segy_info: SegyFileHeaderDump = _get_segy_file_header_dump(segy_file)
542-
543-
segy_dimensions, segy_headers = _scan_for_headers(segy_file, mdio_template, grid_overrides)
544-
545-
grid = _build_and_check_grid(segy_dimensions, segy_file.num_traces, segy_headers)
542+
grid = _build_and_check_grid(segy_dimensions, segy_file_info, segy_headers)
546543

547544
_, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
548545
header_dtype = to_structured_type(segy_spec.trace.header.dtype)
@@ -554,7 +551,7 @@ def segy_to_mdio( # noqa PLR0913
554551
logger.warning("MDIO__IMPORT__RAW_HEADERS is experimental and expected to change or be removed.")
555552
mdio_template = _add_raw_headers_to_template(mdio_template)
556553

557-
horizontal_unit = _get_horizontal_coordinate_unit(segy_info)
554+
horizontal_unit = _get_horizontal_coordinate_unit(segy_file_info)
558555
mdio_ds: Dataset = mdio_template.build_dataset(
559556
name=mdio_template.name,
560557
sizes=grid.shape,
@@ -571,15 +568,14 @@ def segy_to_mdio( # noqa PLR0913
571568

572569
xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)
573570

574-
coordinate_scalar = _get_coordinate_scalar(segy_file)
575571
xr_dataset, drop_vars_delayed = _populate_coordinates(
576572
dataset=xr_dataset,
577573
grid=grid,
578574
coords=non_dim_coords,
579-
horizontal_coordinate_scalar=coordinate_scalar,
575+
horizontal_coordinate_scalar=segy_file_info.coordinate_scalar,
580576
)
581577

582-
xr_dataset = _add_segy_file_headers(xr_dataset, segy_info)
578+
xr_dataset = _add_segy_file_headers(xr_dataset, segy_file_info)
583579

584580
xr_dataset.trace_mask.data[:] = grid.live_mask
585581
# IMPORTANT: Do not drop the "trace_mask" here, as it will be used later in
@@ -600,7 +596,7 @@ def segy_to_mdio( # noqa PLR0913
600596
# This is a memory-expensive and time-consuming read-write operation
601597
# performed in chunks to save the memory
602598
blocked_io.to_zarr(
603-
segy_file=segy_file,
599+
segy_file_kwargs=segy_file_kwargs,
604600
output_path=output_path,
605601
grid_map=grid.map,
606602
dataset=xr_dataset,

src/mdio/segy/_workers.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from __future__ import annotations
44

5+
import logging
56
import os
7+
from dataclasses import dataclass
68
from typing import TYPE_CHECKING
79
from typing import TypedDict
810

@@ -13,9 +15,11 @@
1315
from mdio.api.io import to_mdio
1416
from mdio.builder.schemas.dtype import ScalarType
1517
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
18+
from mdio.segy.scalar import _get_coordinate_scalar
1619

1720
if TYPE_CHECKING:
1821
from segy.config import SegyFileSettings
22+
from segy.config import SegyHeaderOverrides
1923
from segy.schema import SegySpec
2024
from upath import UPath
2125
from xarray import Dataset as xr_Dataset
@@ -29,32 +33,41 @@
2933
from mdio.builder.xarray_builder import _get_fill_value
3034
from mdio.constants import fill_value_map
3135

36+
if TYPE_CHECKING:
37+
from numpy.typing import NDArray
38+
39+
40+
logger = logging.getLogger(__name__)
41+
3242

3343
class SegyFileArguments(TypedDict):
    """Keyword arguments used to create a SegyFile instance.

    Workers receive this mapping and open the file themselves via
    ``SegyFile(**segy_file_kwargs)``.
    """

    # Location of the SEG-Y file.
    url: str
    # SEG-Y spec to open the file with, or None.
    spec: SegySpec | None
    # SEG-Y file settings, or None.
    settings: SegyFileSettings | None
    # SEG-Y header overrides, or None.
    header_overrides: SegyHeaderOverrides | None
3950

4051

4152
def header_scan_worker(
42-
segy_kw: SegyFileArguments, trace_range: tuple[int, int], subset: list[str] | None = None
53+
segy_file_kwargs: SegyFileArguments,
54+
trace_range: tuple[int, int],
55+
subset: list[str] | None = None,
4356
) -> HeaderArray:
4457
"""Header scan worker.
4558
4659
If SegyFile is not open, it can either accept a path string or a handle that was opened in
4760
a different context manager.
4861
4962
Args:
50-
segy_kw: Arguments to open SegyFile instance.
63+
segy_file_kwargs: Arguments to open SegyFile instance.
5164
trace_range: Tuple consisting of the trace ranges to read.
5265
subset: List of header names to filter and keep.
5366
5467
Returns:
5568
HeaderArray parsed from SEG-Y library.
5669
"""
57-
segy_file = SegyFile(**segy_kw)
70+
segy_file = SegyFile(**segy_file_kwargs)
5871

5972
slice_ = slice(*trace_range)
6073

@@ -82,7 +95,7 @@ def header_scan_worker(
8295

8396

8497
def trace_worker( # noqa: PLR0913
85-
segy_kw: SegyFileArguments,
98+
segy_file_kwargs: SegyFileArguments,
8699
output_path: UPath,
87100
data_variable_name: str,
88101
region: dict[str, slice],
@@ -92,7 +105,7 @@ def trace_worker( # noqa: PLR0913
92105
"""Writes a subset of traces from a region of the dataset of Zarr file.
93106
94107
Args:
95-
segy_kw: Arguments to open SegyFile instance.
108+
segy_file_kwargs: Arguments to open SegyFile instance.
96109
output_path: Universal Path for the output Zarr dataset
97110
(e.g. local file path or cloud storage URI) the location
98111
also includes storage options for cloud storage.
@@ -114,7 +127,7 @@ def trace_worker( # noqa: PLR0913
114127
return None
115128

116129
# Open the SEG-Y file in this process since the open file handles cannot be shared across processes.
117-
segy_file = SegyFile(**segy_kw)
130+
segy_file = SegyFile(**segy_file_kwargs)
118131

119132
# Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__MAX_WORKERS` environment variable.
120133
# The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage.
@@ -196,3 +209,52 @@ def trace_worker( # noqa: PLR0913
196209
sum_squares=(np.ma.power(nonzero_samples, 2).sum(dtype="float64")),
197210
histogram=histogram,
198211
)
212+
213+
214+
@dataclass
class SegyFileInfo:
    """Summary of a SEG-Y file, gathered once by ``info_worker``.

    Attributes:
        num_traces: Value of ``SegyFile.num_traces``.
        sample_labels: Value of ``SegyFile.sample_labels``.
        text_header: Value of ``SegyFile.text_header``.
        binary_header_dict: Binary header converted with ``to_dict()``.
        raw_binary_headers: Bytes read from the file at the binary header offset and size
            given by the SEG-Y spec.
        coordinate_scalar: Result of ``_get_coordinate_scalar`` for the file.
    """

    num_traces: int
    sample_labels: NDArray[np.int32]
    text_header: str
    binary_header_dict: dict
    raw_binary_headers: bytes
    coordinate_scalar: int
224+
225+
226+
def info_worker(segy_file_kwargs: SegyFileArguments) -> SegyFileInfo:
    """Open a SEG-Y file and gather its summary information.

    Args:
        segy_file_kwargs: Arguments to open SegyFile instance.

    Returns:
        SegyFileInfo with the trace count, sample labels, text header, parsed and raw
        binary headers, and the coordinate scalar.
    """
    segy_file = SegyFile(**segy_file_kwargs)

    # Keyword arguments are evaluated top to bottom, so the file is accessed in the
    # order: trace count, sample labels, text header, raw binary header, parsed binary
    # header, coordinate scalar.
    return SegyFileInfo(
        num_traces=segy_file.num_traces,
        sample_labels=segy_file.sample_labels,
        text_header=segy_file.text_header,
        # Get the raw binary header bytes directly from the file system.
        raw_binary_headers=segy_file.fs.read_block(
            fn=segy_file.url,
            offset=segy_file.spec.binary_header.offset,
            length=segy_file.spec.binary_header.itemsize,
        ),
        # We read here twice, but it's ok for now. Only 400-bytes.
        binary_header_dict=segy_file.binary_header.to_dict(),
        coordinate_scalar=_get_coordinate_scalar(segy_file),
    )

0 commit comments

Comments
 (0)