Skip to content

Commit a770823

Browse files
authored
Ensure header overrides functionality of TGSAI/segy is available and applied. (TGSAI#700)
* Support SEG-Y header overrides during file ingestion * Clean up redundant SEG-Y file initialization and fix import naming inconsistency * Refactor SEG-Y file initialization and adjust import formatting for clarity * refine coordinate unit extraction logic * Refactor coordinate scalar handling: extract scalar logic into `scalar.py`, apply scalars to coordinates, and streamline related imports * Add unit tests for coordinate scalar getters and apply functions * Fix typo in coordinate scalar key: `cdp_y_` → `cdp_y` * remove redundant checks * more informative warning message * update schema chunk shape for shots * ensure result it wrapped in HeaderArray * add tutorial for dealing with corrupt SEG-Y files. * nicer syntax * fix broken tests after schema change --------- Co-authored-by: Altay Sansal <[email protected]>
1 parent a0df710 commit a770823

File tree

12 files changed

+1675
-21
lines changed

12 files changed

+1675
-21
lines changed

docs/tutorials/corrupt_files.ipynb

Lines changed: 1469 additions & 0 deletions
Large diffs are not rendered by default.

docs/tutorials/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ quickstart
77
creation
88
compression
99
rechunking
10+
corrupt_files
1011
```

src/mdio/builder/templates/seismic_2d_prestack_shot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, data_domain: SeismicDataDomain):
1818
self._coord_dim_names = ("shot_point", "channel")
1919
self._dim_names = (*self._coord_dim_names, self._data_domain)
2020
self._coord_names = ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
21-
self._var_chunk_shape = (16, 64, 1024)
21+
self._var_chunk_shape = (16, 32, 2048)
2222

2323
@property
2424
def _name(self) -> str:

src/mdio/builder/templates/seismic_3d_prestack_shot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, data_domain: SeismicDataDomain):
1818
self._coord_dim_names = ("shot_point", "cable", "channel")
1919
self._dim_names = (*self._coord_dim_names, self._data_domain)
2020
self._coord_names = ("gun", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
21-
self._var_chunk_shape = (8, 2, 128, 1024)
21+
self._var_chunk_shape = (8, 1, 128, 2048)
2222

2323
@property
2424
def _name(self) -> str:

src/mdio/converters/segy.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
import zarr
1313
from segy import SegyFile
1414
from segy.config import SegyFileSettings
15-
from segy.standards.codes import MeasurementSystem as segy_MeasurementSystem
16-
from segy.standards.fields.trace import Rev0 as TraceHeaderFieldsRev0
15+
from segy.config import SegyHeaderOverrides
16+
from segy.standards.codes import MeasurementSystem as SegyMeasurementSystem
17+
from segy.standards.fields import binary as binary_header_fields
1718

1819
from mdio.api.io import _normalize_path
1920
from mdio.api.io import to_mdio
@@ -36,6 +37,9 @@
3637
from mdio.core.utils_write import MAX_SIZE_LIVE_MASK
3738
from mdio.core.utils_write import get_constrained_chunksize
3839
from mdio.segy import blocked_io
40+
from mdio.segy.scalar import SCALE_COORDINATE_KEYS
41+
from mdio.segy.scalar import _apply_coordinate_scalar
42+
from mdio.segy.scalar import _get_coordinate_scalar
3943
from mdio.segy.utilities import get_grid_plan
4044

4145
if TYPE_CHECKING:
@@ -54,6 +58,9 @@
5458
logger = logging.getLogger(__name__)
5559

5660

61+
MEASUREMENT_SYSTEM_KEY = binary_header_fields.Rev0.MEASUREMENT_SYSTEM_CODE.model.name
62+
63+
5764
def grid_density_qc(grid: Grid, num_traces: int) -> None:
5865
"""Quality control for sensible grid density during SEG-Y to MDIO conversion.
5966
@@ -269,6 +276,7 @@ def populate_non_dim_coordinates(
269276
grid: Grid,
270277
coordinates: dict[str, SegyHeaderArray],
271278
drop_vars_delayed: list[str],
279+
horizontal_coordinate_scalar: int,
272280
) -> tuple[xr_Dataset, list[str]]:
273281
"""Populate the xarray dataset with coordinate variables."""
274282
non_data_domain_dims = grid.dim_names[:-1] # minus the data domain dimension
@@ -282,6 +290,10 @@ def populate_non_dim_coordinates(
282290

283291
not_null = coord_trace_indices != grid.map.fill_value
284292
tmp_coord_values[not_null] = coord_values[coord_trace_indices[not_null]]
293+
294+
if coord_name in SCALE_COORDINATE_KEYS:
295+
tmp_coord_values = _apply_coordinate_scalar(tmp_coord_values, horizontal_coordinate_scalar)
296+
285297
dataset[coord_name][:] = tmp_coord_values
286298
drop_vars_delayed.append(coord_name)
287299

@@ -291,16 +303,22 @@ def populate_non_dim_coordinates(
291303
return dataset, drop_vars_delayed
292304

293305

294-
def _get_horizontal_coordinate_unit(segy_headers: list[Dimension]) -> LengthUnitModel | None:
306+
def _get_horizontal_coordinate_unit(segy_info: SegyFileHeaderDump) -> LengthUnitModel | None:
295307
"""Get the coordinate unit from the SEG-Y headers."""
296-
name = TraceHeaderFieldsRev0.COORDINATE_UNIT.name.upper()
297-
unit_hdr = next((c for c in segy_headers if c.name.upper() == name), None)
298-
if unit_hdr is None or len(unit_hdr.coords) == 0:
308+
measurement_system_code = int(segy_info.binary_header_dict[MEASUREMENT_SYSTEM_KEY])
309+
310+
if measurement_system_code not in (1, 2):
311+
logger.warning(
312+
"Unexpected value in coordinate unit (%s) header: %s. Can't extract coordinate unit and will "
313+
"ingest without coordinate units.",
314+
MEASUREMENT_SYSTEM_KEY,
315+
measurement_system_code,
316+
)
299317
return None
300318

301-
if segy_MeasurementSystem(unit_hdr.coords[0]) == segy_MeasurementSystem.METERS:
319+
if measurement_system_code == SegyMeasurementSystem.METERS:
302320
unit = LengthUnitEnum.METER
303-
if segy_MeasurementSystem(unit_hdr.coords[0]) == segy_MeasurementSystem.FEET:
321+
if measurement_system_code == SegyMeasurementSystem.FEET:
304322
unit = LengthUnitEnum.FOOT
305323

306324
return LengthUnitModel(length=unit)
@@ -310,6 +328,7 @@ def _populate_coordinates(
310328
dataset: xr_Dataset,
311329
grid: Grid,
312330
coords: dict[str, SegyHeaderArray],
331+
horizontal_coordinate_scalar: int,
313332
) -> tuple[xr_Dataset, list[str]]:
314333
"""Populate dim and non-dim coordinates in the xarray dataset and write to Zarr.
315334
@@ -319,6 +338,7 @@ def _populate_coordinates(
319338
dataset: The xarray dataset to populate.
320339
grid: The grid object containing the grid map.
321340
coords: The non-dim coordinates to populate.
341+
horizontal_coordinate_scalar: The X/Y coordinate scalar from the SEG-Y file.
322342
323343
Returns:
324344
Xarray dataset with filled coordinates and updated variables to drop after writing
@@ -329,7 +349,11 @@ def _populate_coordinates(
329349

330350
# Populate the non-dimension coordinate variables (N-dim arrays)
331351
dataset, vars_to_drop_later = populate_non_dim_coordinates(
332-
dataset, grid, coordinates=coords, drop_vars_delayed=drop_vars_delayed
352+
dataset,
353+
grid,
354+
coordinates=coords,
355+
drop_vars_delayed=drop_vars_delayed,
356+
horizontal_coordinate_scalar=horizontal_coordinate_scalar,
333357
)
334358

335359
return dataset, drop_vars_delayed
@@ -465,6 +489,7 @@ def segy_to_mdio( # noqa PLR0913
465489
output_path: UPath | Path | str,
466490
overwrite: bool = False,
467491
grid_overrides: dict[str, Any] | None = None,
492+
segy_header_overrides: SegyHeaderOverrides | None = None,
468493
) -> None:
469494
"""A function that converts a SEG-Y file to an MDIO v1 file.
470495
@@ -477,6 +502,7 @@ def segy_to_mdio( # noqa PLR0913
477502
output_path: The universal path for the output MDIO v1 file.
478503
overwrite: Whether to overwrite the output file if it already exists. Defaults to False.
479504
grid_overrides: Option to add grid overrides.
505+
segy_header_overrides: Option to override specific SEG-Y headers during ingestion.
480506
481507
Raises:
482508
FileExistsError: If the output location already exists and overwrite is False.
@@ -489,7 +515,12 @@ def segy_to_mdio( # noqa PLR0913
489515
raise FileExistsError(err)
490516

491517
segy_settings = SegyFileSettings(storage_options=input_path.storage_options)
492-
segy_file = SegyFile(url=input_path.as_posix(), spec=segy_spec, settings=segy_settings)
518+
segy_file = SegyFile(
519+
url=input_path.as_posix(),
520+
spec=segy_spec,
521+
settings=segy_settings,
522+
header_overrides=segy_header_overrides,
523+
)
493524
segy_info: SegyFileHeaderDump = _get_segy_file_header_dump(segy_file)
494525

495526
segy_dimensions, segy_headers = _scan_for_headers(segy_file, mdio_template, grid_overrides)
@@ -506,7 +537,7 @@ def segy_to_mdio( # noqa PLR0913
506537
logger.warning("MDIO__IMPORT__RAW_HEADERS is experimental and expected to change or be removed.")
507538
mdio_template = _add_raw_headers_to_template(mdio_template)
508539

509-
horizontal_unit = _get_horizontal_coordinate_unit(segy_dimensions)
540+
horizontal_unit = _get_horizontal_coordinate_unit(segy_info)
510541
mdio_ds: Dataset = mdio_template.build_dataset(
511542
name=mdio_template.name,
512543
sizes=grid.shape,
@@ -523,10 +554,12 @@ def segy_to_mdio( # noqa PLR0913
523554

524555
xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)
525556

557+
coordinate_scalar = _get_coordinate_scalar(segy_file)
526558
xr_dataset, drop_vars_delayed = _populate_coordinates(
527559
dataset=xr_dataset,
528560
grid=grid,
529561
coords=non_dim_coords,
562+
horizontal_coordinate_scalar=coordinate_scalar,
530563
)
531564

532565
xr_dataset = _add_segy_file_headers(xr_dataset, segy_info)

src/mdio/segy/_workers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@
55
import os
66
from typing import TYPE_CHECKING
77
from typing import TypedDict
8-
from typing import cast
98

109
import numpy as np
1110
from segy import SegyFile
11+
from segy.arrays import HeaderArray
1212

1313
from mdio.api.io import to_mdio
1414
from mdio.builder.schemas.dtype import ScalarType
1515
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
1616

1717
if TYPE_CHECKING:
18-
from segy.arrays import HeaderArray
1918
from segy.config import SegyFileSettings
2019
from segy.schema import SegySpec
2120
from upath import UPath
@@ -79,7 +78,7 @@ def header_scan_worker(
7978
# (singleton) so we can concat and assign stuff later.
8079
trace_header = np.array(trace_header, dtype=new_dtype, ndmin=1)
8180

82-
return cast("HeaderArray", trace_header)
81+
return HeaderArray(trace_header) # wrap back so we can use aliases
8382

8483

8584
def trace_worker( # noqa: PLR0913

src/mdio/segy/blocked_io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def to_zarr( # noqa: PLR0913, PLR0915
9191
"url": segy_file.fs.unstrip_protocol(segy_file.url),
9292
"spec": segy_file.spec,
9393
"settings": segy_file.settings,
94+
"header_overrides": segy_file.header_overrides,
9495
}
9596
with executor:
9697
futures = []

src/mdio/segy/parsers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def parse_headers(
2626
segy_file: SegyFile,
27-
subset: list[str] | None = None,
27+
subset: tuple[str, ...] | None = None,
2828
block_size: int = 10000,
2929
progress_bar: bool = True,
3030
) -> HeaderArray:

src/mdio/segy/scalar.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Utilities to read, parse, and apply coordinate scalars."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import TYPE_CHECKING
7+
8+
from segy.standards import SegyStandard
9+
from segy.standards.fields import trace as trace_header_fields
10+
11+
if TYPE_CHECKING:
12+
from numpy.typing import NDArray
13+
from segy import SegyFile
14+
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
COORD_SCALAR_KEY = trace_header_fields.Rev0.COORDINATE_SCALAR.model.name
20+
VALID_COORD_SCALAR = {1, 10, 100, 1000, 10000}
21+
SCALE_COORDINATE_KEYS = [
22+
"cdp_x",
23+
"cdp_y",
24+
"source_coord_x",
25+
"source_coord_y",
26+
"group_coord_x",
27+
"group_coord_y",
28+
]
29+
30+
31+
def _get_coordinate_scalar(segy_file: SegyFile) -> int:
32+
"""Get and parse the coordinate scalar from the first SEG-Y trace header."""
33+
file_revision = segy_file.spec.segy_standard
34+
first_header = segy_file.header[0]
35+
coord_scalar = int(first_header[COORD_SCALAR_KEY])
36+
37+
# Per Rev2, standardize 0 to 1 if a file is 2+.
38+
if coord_scalar == 0 and file_revision >= SegyStandard.REV2:
39+
logger.warning("Coordinate scalar is 0 and file is %s. Setting to 1.", file_revision)
40+
return 1
41+
42+
def validate_segy_scalar(scalar: int) -> bool:
43+
"""Validate if coord scalar matches the seg-y standard."""
44+
logger.debug("Coordinate scalar is %s", scalar)
45+
return abs(scalar) in VALID_COORD_SCALAR # valid values
46+
47+
is_valid = validate_segy_scalar(coord_scalar)
48+
if not is_valid:
49+
msg = f"Invalid coordinate scalar: {coord_scalar} for file revision {file_revision}."
50+
raise ValueError(msg)
51+
52+
logger.info("Coordinate scalar is parsed as %s", coord_scalar)
53+
return coord_scalar
54+
55+
56+
def _apply_coordinate_scalar(data: NDArray, scalar: int) -> NDArray:
57+
if scalar < 0:
58+
scalar = 1 / scalar
59+
return data * abs(scalar)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Tests for coordinate scalar getters and apply functions."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
from unittest.mock import MagicMock
7+
8+
import numpy as np
9+
import pytest
10+
from segy import SegyFile
11+
from segy.standards import SegyStandard
12+
from segy.standards.fields import trace as trace_header_fields
13+
14+
from mdio.segy.scalar import _apply_coordinate_scalar
15+
from mdio.segy.scalar import _get_coordinate_scalar
16+
17+
if TYPE_CHECKING:
18+
from numpy.typing import NDArray
19+
20+
COORD_SCALAR_KEY = trace_header_fields.Rev0.COORDINATE_SCALAR.model.name
21+
22+
23+
@pytest.fixture
24+
def mock_segy_file() -> SegyFile:
25+
"""Mock SegyFile object."""
26+
segy_file = MagicMock(spec=SegyFile)
27+
segy_file.spec = MagicMock()
28+
segy_file.header = [MagicMock()]
29+
return segy_file
30+
31+
32+
@pytest.mark.parametrize("scalar", [1, 100, 10000, -10, -1000])
33+
def test_get_coordinate_scalar_valid(mock_segy_file: SegyFile, scalar: int) -> None:
34+
"""Test valid options when getting coordinate scalar."""
35+
mock_segy_file.spec.segy_standard = SegyStandard.REV1
36+
mock_segy_file.header[0].__getitem__.return_value = scalar
37+
38+
result = _get_coordinate_scalar(mock_segy_file)
39+
40+
assert result == scalar
41+
42+
43+
@pytest.mark.parametrize(
44+
"revision",
45+
[SegyStandard.REV2, SegyStandard.REV21],
46+
)
47+
def test_get_coordinate_scalar_zero_rev2_plus(mock_segy_file: SegyFile, revision: SegyStandard) -> None:
48+
"""Test when scalar is normalized to 1 (from 0) in Rev2+."""
49+
mock_segy_file.spec.segy_standard = revision
50+
mock_segy_file.header[0].__getitem__.return_value = 0
51+
52+
result = _get_coordinate_scalar(mock_segy_file)
53+
54+
assert result == 1
55+
56+
57+
@pytest.mark.parametrize(
58+
("scalar", "revision", "error_msg"),
59+
[
60+
(0, SegyStandard.REV0, "Invalid coordinate scalar: 0 for file revision SegyStandard.REV0."),
61+
(110, SegyStandard.REV1, "Invalid coordinate scalar: 110 for file revision SegyStandard.REV1."),
62+
(32768, SegyStandard.REV1, "Invalid coordinate scalar: 32768 for file revision SegyStandard.REV1."),
63+
],
64+
)
65+
def test_get_coordinate_scalar_invalid(
66+
mock_segy_file: SegyFile, scalar: int, revision: SegyStandard, error_msg: str
67+
) -> None:
68+
"""Test invalid options when getting coordinate scalar."""
69+
mock_segy_file.spec.segy_standard = revision
70+
mock_segy_file.header[0].__getitem__.return_value = scalar
71+
72+
with pytest.raises(ValueError, match=error_msg):
73+
_get_coordinate_scalar(mock_segy_file)
74+
75+
76+
@pytest.mark.parametrize(
77+
("data", "scalar", "expected"),
78+
[
79+
# POSITIVE
80+
(np.array([1, 2, 3]), 1, np.array([1, 2, 3])),
81+
(np.array([1, 2, 3]), 10, np.array([10, 20, 30])),
82+
(np.array([[1, 2], [3, 4]]), 1000, np.array([[1000, 2000], [3000, 4000]])),
83+
# NEGATIVE
84+
(np.array([1, 2, 3]), -1, np.array([1, 2, 3])),
85+
(np.array([10, 20, 30]), -10, np.array([1, 2, 3])),
86+
(np.array([[1000, 2000], [3000, 4000]]), -1000, np.array([[1, 2], [3, 4]])),
87+
],
88+
)
89+
def test_apply_coordinate_scalar(data: NDArray, scalar: int, expected: NDArray) -> None:
90+
"""Test applying coordinate scalar with negative and positive code."""
91+
result = _apply_coordinate_scalar(data, scalar)
92+
assert np.allclose(result, expected)

0 commit comments

Comments
 (0)