Skip to content

Commit 806b32b

Browse files
committed
Implement working test for DuplicateTraces override
1 parent ef29660 commit 806b32b

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

src/mdio/builder/dataset_builder.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,9 @@ def push_dimension(self, dimension: NamedDimension, position: int, new_dim_chunk
134134
def propogate_dimension(variable: Variable, position: int, new_dim_chunk_size: int) -> Variable:
135135
"""Propagates the dimension to the variable or coordinate."""
136136
from mdio.builder.schemas.chunk_grid import RegularChunkGrid, RegularChunkShape
137-
if len(variable.dimensions) <= position:
137+
if len(variable.dimensions) + 1 <= position:
138138
# Don't do anything if the new dimension is not within the Variable's domain
139139
return variable
140-
if variable.name == "trace_mask":
141-
# Special case for trace_mask. Don't do anything.
142-
return variable
143140
new_dimensions = variable.dimensions[:position] + [dimension] + variable.dimensions[position:]
144141

145142
# Get current chunk shape from metadata

src/mdio/segy/utilities.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,46 @@
2626
logger = logging.getLogger(__name__)
2727

2828

29+
def _create_delayed_trace_dimension_transform(headers_subset: HeaderArray, position: int) -> callable:
    """Create a delayed transform that adds a trace dimension and its coordinate.

    The returned closure captures ``headers_subset`` and ``position``. Nothing is
    computed here; the trace dimension size is derived from the headers only when
    the closure is called with a builder.

    Args:
        headers_subset: The header array containing trace information. It must
            expose a "trace" field by the time the transform is executed.
        position: The position where the trace dimension should be inserted.

    Returns:
        A callable that accepts a builder, pushes a "trace" dimension at
        ``position`` (with a chunk size of 1), adds an INT32 "trace" coordinate
        over that dimension, and returns the same builder.
    """

    def delayed_transform(builder):
        """Apply the trace dimension and coordinate to ``builder``.

        Raises:
            ValueError: If ``headers_subset`` has no "trace" field.
        """
        from mdio.builder.schemas.dtype import ScalarType

        # ``dtype.names`` is None for unstructured dtypes; treat that as "no fields"
        # so the caller gets the descriptive ValueError instead of a TypeError.
        field_names = headers_subset.dtype.names or ()
        if "trace" not in field_names:
            # There is no alternative way to size the dimension, so fail loudly.
            raise ValueError("Trace field not found in headers_subset when executing delayed transform")

        # The dimension size is the largest trace value, evaluated at execution time.
        trace_size = int(np.max(headers_subset["trace"]))

        # Add the trace dimension.
        trace_dimension = NamedDimension(name="trace", size=trace_size)
        builder.push_dimension(trace_dimension, position=position, new_dim_chunk_size=1)

        # Add the corresponding coordinate for the trace dimension.
        builder.add_coordinate(
            "trace",
            dimensions=("trace",),
            data_type=ScalarType.INT32,
        )

        return builder

    return delayed_transform
67+
68+
2969
def get_grid_plan( # noqa: C901
3070
segy_file: SegyFile,
3171
chunksize: tuple[int, ...] | None,
@@ -70,7 +110,8 @@ def get_grid_plan( # noqa: C901
70110

71111
if grid_overrides.get("HasDuplicates", False):
72112
pos = len(template.dimension_names) - 1 # TODO: Implement the negative position case...
73-
template._queue_transform(lambda builder: builder.push_dimension(NamedDimension(name="trace", size=np.max(headers_subset["trace"])), position=pos, new_dim_chunk_size=1))
113+
# Use the delayed transform function instead of a simple lambda
114+
template._queue_transform(_create_delayed_trace_dimension_transform(headers_subset, pos))
74115
horizontal_dimensions = (*horizontal_dimensions, "trace")
75116

76117
dimensions = []

tests/integration/test_import_streamer_grid_overrides.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def test_import_4d_segy( # noqa: PLR0913
5252
segy_spec=segy_spec,
5353
mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"),
5454
input_path=segy_path,
55-
# output_path=zarr_tmp,
56-
output_path="test_has_duplicates.mdio",
55+
output_path=zarr_tmp,
5756
overwrite=True,
5857
grid_overrides=grid_override,
5958
)
@@ -69,10 +68,9 @@ def test_import_4d_segy( # noqa: PLR0913
6968
assert ds["segy_file_header"].attrs["binaryHeader"]["samples_per_trace"] == num_samples
7069
assert ds.attrs["attributes"]["gridOverrides"] == grid_override
7170

72-
assert npt.assert_array_equal(ds["shot_point"], shots)
71+
xrt.assert_duckarray_equal(ds["shot_point"], shots)
7372
xrt.assert_duckarray_equal(ds["cable"], cables)
7473

75-
# assert grid.select_dim("trace") == Dimension(range(1, np.amax(receivers_per_cable) + 1), "trace")
7674
expected = list(range(1, np.amax(receivers_per_cable) + 1))
7775
xrt.assert_duckarray_equal(ds["trace"], expected)
7876

0 commit comments

Comments
 (0)