Skip to content

Commit ef3505f

Browse files
committed
Working example of DuplicateIndex grid override
1 parent 693bf0b commit ef3505f

File tree

4 files changed

+75
-16
lines changed

4 files changed

+75
-16
lines changed

src/mdio/converters/segy.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,11 @@ def _scan_for_headers(
156156
"""Extract trace dimensions and index headers from the SEG-Y file.
157157
158158
This is an expensive operation.
159-
It scans the SEG-Y file in chunks by using ProcessPoolExecutor
159+
It scans the SEG-Y file in chunks by using ProcessPoolExecutor.
160+
161+
Note:
162+
If grid_overrides are applied to the template before calling this function,
163+
the chunk_size returned from get_grid_plan should match the template's chunk shape.
160164
"""
161165
full_chunk_shape = template.full_chunk_shape
162166
segy_dimensions, chunk_size, segy_headers = get_grid_plan(
@@ -167,13 +171,19 @@ def _scan_for_headers(
167171
chunksize=full_chunk_shape,
168172
grid_overrides=grid_overrides,
169173
)
174+
175+
# After applying grid overrides to the template, chunk sizes should match
176+
# If they don't match, it means the template wasn't properly updated
170177
if full_chunk_shape != chunk_size:
171-
# TODO(Dmitriy): implement grid overrides
172-
# https://github.com/TGSAI/mdio-python/issues/585
173-
# The returned 'chunksize' is used only for grid_overrides. We will need to use it when full
174-
# support for grid overrides is implemented
175-
err = "Support for changing full_chunk_shape in grid overrides is not yet implemented"
176-
raise NotImplementedError(err)
178+
logger.warning(
179+
"Chunk shape mismatch: template has %s but grid_plan returned %s. "
180+
"Using grid_plan chunk shape.",
181+
full_chunk_shape,
182+
chunk_size,
183+
)
184+
# Update the template's chunk shape to match what grid_plan returned
185+
template._var_chunk_shape = chunk_size
186+
177187
return segy_dimensions, segy_headers
178188

179189

@@ -562,6 +572,17 @@ def segy_to_mdio( # noqa PLR0913
562572
)
563573
grid = _build_and_check_grid(segy_dimensions, segy_file_info, segy_headers)
564574

575+
# Update template dimensions to match the actual grid dimensions after grid overrides
576+
# The chunk shape was already updated in _scan_for_headers, we just need to fix dimensions
577+
actual_spatial_dims = tuple(grid.dim_names[:-1]) # All dims except the vertical/time dimension
578+
if mdio_template.spatial_dimension_names != actual_spatial_dims:
579+
logger.info(
580+
"Adjusting template dimensions from %s to %s to match grid after overrides",
581+
mdio_template.spatial_dimension_names,
582+
actual_spatial_dims,
583+
)
584+
mdio_template._dim_names = actual_spatial_dims + (mdio_template.trace_domain,)
585+
565586
_, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
566587
header_dtype = to_structured_type(segy_spec.trace.header.dtype)
567588

src/mdio/segy/geometry.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,8 @@ def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = n
267267
header_names = []
268268
for header_key in index_headers.dtype.names:
269269
if header_key != "trace":
270-
unique_headers[header_key] = np.sort(np.unique(index_headers[header_key]))
270+
unique_vals = np.sort(np.unique(index_headers[header_key]))
271+
unique_headers[header_key] = unique_vals
271272
header_names.append(header_key)
272273
total_depth += 1
273274

@@ -382,7 +383,39 @@ def transform(
382383
"""Perform the grid transform."""
383384
self.validate(index_headers, grid_overrides)
384385

385-
return analyze_non_indexed_headers(index_headers)
386+
# Filter to only include dimension fields, not coordinate fields
387+
# Coordinate fields typically have _x, _y suffixes or specific names like 'gun'
388+
# We want to keep fields like shot_point, cable, channel but exclude source_coord_x, etc.
389+
dimension_fields = []
390+
coordinate_field_patterns = ['_x', '_y', '_coord', 'gun', 'source', 'group']
391+
392+
for field_name in index_headers.dtype.names:
393+
# Skip if it's already trace
394+
if field_name == 'trace':
395+
continue
396+
# Check if it looks like a coordinate field
397+
is_coordinate = any(pattern in field_name for pattern in coordinate_field_patterns)
398+
if not is_coordinate:
399+
dimension_fields.append(field_name)
400+
401+
# Extract only dimension fields for trace indexing
402+
if dimension_fields:
403+
dimension_headers = index_headers[dimension_fields]
404+
else:
405+
# If no dimension fields, use all fields
406+
dimension_headers = index_headers
407+
408+
# Create trace indices based on dimension fields only
409+
dimension_headers_with_trace = analyze_non_indexed_headers(dimension_headers)
410+
411+
# Add the trace field back to the full index_headers array
412+
if dimension_headers_with_trace is not None and 'trace' in dimension_headers_with_trace.dtype.names:
413+
# Extract just the trace values array (not the whole structured array)
414+
trace_values = np.array(dimension_headers_with_trace['trace'])
415+
# Append as a new field to the full headers
416+
index_headers = rfn.append_fields(index_headers, 'trace', trace_values, usemask=False)
417+
418+
return index_headers
386419

387420
def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]:
388421
"""Insert dimension "trace" to the sample-1 dimension."""

src/mdio/segy/utilities.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,15 @@ def get_grid_plan( # noqa: C901, PLR0913
7272
chunksize=chunksize,
7373
grid_overrides=grid_overrides,
7474
)
75+
# Use the spatial dimension names from horizontal_coordinates (which may have been modified by grid overrides)
76+
# Extract only the dimension names (not including non-dimension coordinates)
77+
# After grid overrides, trace might have been added to horizontal_coordinates
78+
transformed_spatial_dims = [name for name in horizontal_coordinates if name in horizontal_dimensions or name == "trace"]
7579

7680
dimensions = []
77-
for dim_name in horizontal_dimensions:
81+
for dim_name in transformed_spatial_dims:
82+
if dim_name not in headers_subset.dtype.names:
83+
continue
7884
dim_unique = np.unique(headers_subset[dim_name])
7985
dimensions.append(Dimension(coords=dim_unique, name=dim_name))
8086

tests/integration/test_import_streamer_grid_overrides.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@
2727
dask.config.set(scheduler="synchronous")
2828
os.environ["MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"] = "true"
2929

30-
31-
# TODO(Altay): Finish implementing these grid overrides.
30+
# TODO(BrianMichell): Add non-binned back
3231
# https://github.com/TGSAI/mdio-python/issues/612
33-
@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.")
34-
@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}])
32+
# @pytest.mark.parametrize("grid_override", [{"NonBinned": True, "chunksize": 4}, {"HasDuplicates": True}])
33+
@pytest.mark.parametrize("grid_override", [{"HasDuplicates": True}])
3534
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C])
36-
class TestImport4DNonReg: # pragma: no cover - tests is skipped
35+
class TestImport4DNonReg:
3736
"""Test for 4D segy import with grid overrides."""
3837

3938
def test_import_4d_segy( # noqa: PLR0913
@@ -67,7 +66,7 @@ def test_import_4d_segy( # noqa: PLR0913
6766
assert ds["segy_file_header"].attrs["binaryHeader"]["samples_per_trace"] == num_samples
6867
assert ds.attrs["attributes"]["gridOverrides"] == grid_override
6968

70-
assert npt.assert_array_equal(ds["shot_point"], shots)
69+
xrt.assert_duckarray_equal(ds["shot_point"], shots)
7170
xrt.assert_duckarray_equal(ds["cable"], cables)
7271

7372
# assert grid.select_dim("trace") == Dimension(range(1, np.amax(receivers_per_cable) + 1), "trace")

0 commit comments

Comments
 (0)