Skip to content

Commit 0cf9667

Browse files
committed
Remove hardcoded values + linting
1 parent ef3505f commit 0cf9667

File tree

4 files changed

+29
-25
lines changed

4 files changed

+29
-25
lines changed

src/mdio/converters/segy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,7 @@ def _scan_for_headers(
176176
# If they don't match, it means the template wasn't properly updated
177177
if full_chunk_shape != chunk_size:
178178
logger.warning(
179-
"Chunk shape mismatch: template has %s but grid_plan returned %s. "
180-
"Using grid_plan chunk shape.",
179+
"Chunk shape mismatch: template has %s but grid_plan returned %s. Using grid_plan chunk shape.",
181180
full_chunk_shape,
182181
chunk_size,
183182
)

src/mdio/segy/geometry.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from numpy.typing import NDArray
2525
from segy.arrays import HeaderArray
2626

27+
from mdio.builder.templates.base import AbstractDatasetTemplate
28+
2729

2830
logger = logging.getLogger(__name__)
2931

@@ -303,6 +305,7 @@ def transform(
303305
self,
304306
index_headers: HeaderArray,
305307
grid_overrides: dict[str, bool | int],
308+
template: AbstractDatasetTemplate, # noqa: ARG002
306309
) -> NDArray:
307310
"""Perform the grid transform."""
308311

@@ -379,42 +382,38 @@ def transform(
379382
self,
380383
index_headers: HeaderArray,
381384
grid_overrides: dict[str, bool | int],
385+
template: AbstractDatasetTemplate, # noqa: ARG002
382386
) -> NDArray:
383387
"""Perform the grid transform."""
384388
self.validate(index_headers, grid_overrides)
385389

386390
# Filter to only include dimension fields, not coordinate fields
387-
# Coordinate fields typically have _x, _y suffixes or specific names like 'gun'
388-
# We want to keep fields like shot_point, cable, channel but exclude source_coord_x, etc.
391+
# We want to keep fields like shot_point, cable, channel but exclude coordinate fields
392+
# Use the template's coordinate names to determine which fields are coordinates
393+
coordinate_fields = set(template.coordinate_names)
389394
dimension_fields = []
390-
coordinate_field_patterns = ['_x', '_y', '_coord', 'gun', 'source', 'group']
391-
395+
392396
for field_name in index_headers.dtype.names:
393397
# Skip the field if it is already named "trace"
394-
if field_name == 'trace':
398+
if field_name == "trace":
395399
continue
396-
# Check if it looks like a coordinate field
397-
is_coordinate = any(pattern in field_name for pattern in coordinate_field_patterns)
398-
if not is_coordinate:
400+
# Check if this field is a coordinate field according to the template
401+
if field_name not in coordinate_fields:
399402
dimension_fields.append(field_name)
400-
403+
401404
# Extract only dimension fields for trace indexing
402-
if dimension_fields:
403-
dimension_headers = index_headers[dimension_fields]
404-
else:
405-
# If no dimension fields, use all fields
406-
dimension_headers = index_headers
407-
405+
dimension_headers = index_headers[dimension_fields] if dimension_fields else index_headers
406+
408407
# Create trace indices based on dimension fields only
409408
dimension_headers_with_trace = analyze_non_indexed_headers(dimension_headers)
410-
409+
411410
# Add the trace field back to the full index_headers array
412-
if dimension_headers_with_trace is not None and 'trace' in dimension_headers_with_trace.dtype.names:
411+
if dimension_headers_with_trace is not None and "trace" in dimension_headers_with_trace.dtype.names:
413412
# Extract just the trace values array (not the whole structured array)
414-
trace_values = np.array(dimension_headers_with_trace['trace'])
413+
trace_values = np.array(dimension_headers_with_trace["trace"])
415414
# Append as a new field to the full headers
416-
index_headers = rfn.append_fields(index_headers, 'trace', trace_values, usemask=False)
417-
415+
index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False)
416+
418417
return index_headers
419418

420419
def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]:
@@ -467,6 +466,7 @@ def transform(
467466
self,
468467
index_headers: HeaderArray,
469468
grid_overrides: dict[str, bool | int],
469+
template: AbstractDatasetTemplate, # noqa: ARG002
470470
) -> NDArray:
471471
"""Perform the grid transform."""
472472
self.validate(index_headers, grid_overrides)
@@ -504,6 +504,7 @@ def transform(
504504
self,
505505
index_headers: HeaderArray,
506506
grid_overrides: dict[str, bool | int],
507+
template: AbstractDatasetTemplate, # noqa: ARG002
507508
) -> NDArray:
508509
"""Perform the grid transform."""
509510
self.validate(index_headers, grid_overrides)
@@ -565,6 +566,7 @@ def run(
565566
index_names: Sequence[str],
566567
grid_overrides: dict[str, bool],
567568
chunksize: Sequence[int] | None = None,
569+
template: AbstractDatasetTemplate | None = None,
568570
) -> tuple[HeaderArray, tuple[str], tuple[int]]:
569571
"""Run grid overrides and return result."""
570572
for override in grid_overrides:
@@ -575,7 +577,7 @@ def run(
575577
raise GridOverrideUnknownError(override)
576578

577579
function = self.commands[override].transform
578-
index_headers = function(index_headers, grid_overrides=grid_overrides)
580+
index_headers = function(index_headers, grid_overrides=grid_overrides, template=template)
579581

580582
function = self.commands[override].transform_index_names
581583
index_names = function(index_names)

src/mdio/segy/utilities.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@ def get_grid_plan( # noqa: C901, PLR0913
7171
horizontal_coordinates,
7272
chunksize=chunksize,
7373
grid_overrides=grid_overrides,
74+
template=template,
7475
)
7576
# Use the spatial dimension names from horizontal_coordinates (which may have been modified by grid overrides)
7677
# Extract only the dimension names (not including non-dimension coordinates)
7778
# After grid overrides, trace might have been added to horizontal_coordinates
78-
transformed_spatial_dims = [name for name in horizontal_coordinates if name in horizontal_dimensions or name == "trace"]
79+
transformed_spatial_dims = [
80+
name for name in horizontal_coordinates if name in horizontal_dimensions or name == "trace"
81+
]
7982

8083
dimensions = []
8184
for dim_name in transformed_spatial_dims:

tests/integration/test_import_streamer_grid_overrides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import dask
99
import numpy as np
10-
import numpy.testing as npt
1110
import pytest
1211
import xarray.testing as xrt
1312
from tests.integration.conftest import get_segy_mock_4d_spec
@@ -27,6 +26,7 @@
2726
dask.config.set(scheduler="synchronous")
2827
os.environ["MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"] = "true"
2928

29+
3030
# TODO(BrianMichell): Add non-binned back
3131
# https://github.com/TGSAI/mdio-python/issues/612
3232
# @pytest.mark.parametrize("grid_override", [{"NonBinned": True, "chunksize": 4}, {"HasDuplicates": True}])

0 commit comments

Comments
 (0)