Skip to content

Commit e04678f

Browse files
BrianMichelltasansal
authored andcommitted
Fully functional demo
1 parent 92184a9 commit e04678f

File tree

6 files changed

+89
-22
lines changed

6 files changed

+89
-22
lines changed

src/mdio/builder/dataset_builder.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,32 @@ def add_coordinate( # noqa: PLR0913
150150
msg = "Adding coordinate with the same name twice is not allowed"
151151
raise ValueError(msg)
152152

153-
# Validate that all referenced dimensions are already defined
153+
# Resolve referenced dimensions strictly, except allow a single substitution with 'trace' if present.
154154
named_dimensions = []
155+
trace_dim = _get_named_dimension(self._dimensions, "trace")
156+
resolved_dim_names: list[str] = []
157+
trace_used = False
158+
missing_dims: list[str] = []
155159
for dim_name in dimensions:
156160
nd = _get_named_dimension(self._dimensions, dim_name)
161+
if nd is not None:
162+
if dim_name not in resolved_dim_names:
163+
resolved_dim_names.append(dim_name)
164+
continue
165+
if trace_dim is not None and not trace_used and "trace" not in resolved_dim_names:
166+
resolved_dim_names.append("trace")
167+
trace_used = True
168+
else:
169+
missing_dims.append(dim_name)
170+
171+
if missing_dims:
172+
msg = f"Pre-existing dimension named {missing_dims[0]!r} is not found"
173+
raise ValueError(msg)
174+
175+
for resolved_name in resolved_dim_names:
176+
nd = _get_named_dimension(self._dimensions, resolved_name)
157177
if nd is None:
158-
msg = f"Pre-existing dimension named {dim_name!r} is not found"
178+
msg = f"Pre-existing dimension named {resolved_name!r} is not found"
159179
raise ValueError(msg)
160180
named_dimensions.append(nd)
161181

@@ -174,7 +194,7 @@ def add_coordinate( # noqa: PLR0913
174194
self.add_variable(
175195
name=coord.name,
176196
long_name=coord.long_name,
177-
dimensions=dimensions, # dimension names (list[str])
197+
dimensions=tuple(resolved_dim_names), # resolved dimension names
178198
data_type=coord.data_type,
179199
compressor=compressor,
180200
coordinates=[name], # Use the coordinate name as a reference

src/mdio/builder/templates/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,21 @@ def build_dataset(
8989
self._builder = MDIODatasetBuilder(name=name, attributes=attributes)
9090
self._add_dimensions()
9191
self._add_coordinates()
92+
# Ensure any coordinates declared on the template but not added by subclass overrides
93+
# are materialized with generic defaults. This keeps templates override-agnostic while
94+
# allowing runtime-augmented coordinate lists to be respected.
95+
for coord_name in self.coordinate_names:
96+
try:
97+
self._builder.add_coordinate(
98+
name=coord_name,
99+
dimensions=self.spatial_dimension_names,
100+
data_type=ScalarType.FLOAT64,
101+
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
102+
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(coord_name)),
103+
)
104+
except ValueError as exc: # coordinate may already exist from subclass override
105+
if "same name twice" not in str(exc):
106+
raise
92107
self._add_variables()
93108
self._add_trace_mask()
94109

src/mdio/converters/segy.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,19 @@ def _scan_for_headers(
178178
)
179179
template._dim_names = actual_spatial_dims + (template.trace_domain,)
180180

181-
# Handle NonBinned: move non-binned dimensions to coordinates
181+
# If using NonBinned override, expose non-binned dims as logical coordinates on the template instance
182182
if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides:
183183
non_binned_dims = tuple(grid_overrides["non_binned_dims"])
184184
if non_binned_dims:
185185
logger.debug(
186-
"NonBinned grid override: moving dimensions %s to coordinates",
186+
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
187187
non_binned_dims,
188188
)
189-
# Add non-binned dimensions as logical coordinates
190-
template._logical_coord_names = template._logical_coord_names + non_binned_dims
189+
# Append any missing names; keep existing order and avoid duplicates
190+
existing = set(template.coordinate_names)
191+
to_add = tuple(n for n in non_binned_dims if n not in existing)
192+
if to_add:
193+
template._logical_coord_names = template._logical_coord_names + to_add
191194

192195
return segy_dimensions, segy_headers
193196

src/mdio/segy/utilities.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,28 @@ def get_grid_plan( # noqa: C901, PLR0913
8282
# Use the spatial dimension names from horizontal_coordinates (which may have been modified by grid overrides)
8383
# Extract only the dimension names (not including non-dimension coordinates or non-binned dimensions)
8484
# After grid overrides, trace might have been added to horizontal_coordinates
85-
transformed_spatial_dims = [
86-
name
87-
for name in horizontal_coordinates
88-
if (name in horizontal_dimensions or name == "trace") and name not in non_binned_dims
89-
]
85+
# Compute transformed spatial dims: drop non-binned dims, insert trace if present in headers
86+
transformed_spatial_dims = []
87+
for name in horizontal_coordinates:
88+
if name in non_binned_dims:
89+
continue
90+
if name == "trace" or name in horizontal_dimensions:
91+
transformed_spatial_dims.append(name)
92+
93+
# Recompute chunksize to match transformed dims
94+
original_spatial_dims = list(template.spatial_dimension_names)
95+
original_chunks = list(template.full_chunk_shape)
96+
new_spatial_chunks: list[int] = []
97+
# Insert trace chunk size at N-1 when present, otherwise map remaining dims
98+
for dim_name in transformed_spatial_dims:
99+
if dim_name == "trace":
100+
chunk_val = int(grid_overrides.get("chunksize", 1)) if "NonBinned" in grid_overrides else 1
101+
new_spatial_chunks.append(chunk_val)
102+
continue
103+
if dim_name in original_spatial_dims:
104+
idx = original_spatial_dims.index(dim_name)
105+
new_spatial_chunks.append(original_chunks[idx])
106+
chunksize = tuple(new_spatial_chunks + [original_chunks[-1]])
90107

91108
dimensions = []
92109
for dim_name in transformed_spatial_dims:

tests/integration/test_import_streamer_grid_overrides.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
os.environ["MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"] = "true"
2828

2929

30-
@pytest.mark.parametrize("grid_override", [{"HasDuplicates": True}])
30+
# @pytest.mark.parametrize("grid_override", [{"NonBinned": True, "chunksize": 4, "non_binned_dims": ["channel"]}, {"HasDuplicates": True}])
31+
@pytest.mark.parametrize("grid_override", [{"NonBinned": True, "chunksize": 4, "non_binned_dims": ["channel"]}])
3132
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C])
3233
class TestImport4DNonReg:
3334
"""Test for 4D segy import with grid overrides."""
@@ -43,11 +44,14 @@ def test_import_4d_segy( # noqa: PLR0913
4344
segy_spec: SegySpec = get_segy_mock_4d_spec()
4445
segy_path = segy_mock_4d_shots[chan_header_type]
4546

47+
path = "tmp.mdio"
48+
print(f"Running test with grid override: {grid_override}")
49+
4650
segy_to_mdio(
4751
segy_spec=segy_spec,
4852
mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"),
4953
input_path=segy_path,
50-
output_path=zarr_tmp,
54+
output_path=path,
5155
overwrite=True,
5256
grid_overrides=grid_override,
5357
)
@@ -58,25 +62,31 @@ def test_import_4d_segy( # noqa: PLR0913
5862
cables = [0, 101, 201, 301]
5963
receivers_per_cable = [1, 5, 7, 5]
6064

61-
ds = open_mdio(zarr_tmp)
65+
ds = open_mdio(path)
6266

6367
assert ds["segy_file_header"].attrs["binaryHeader"]["samples_per_trace"] == num_samples
6468
assert ds.attrs["attributes"]["gridOverrides"] == grid_override
6569

6670
xrt.assert_duckarray_equal(ds["shot_point"], shots)
6771
xrt.assert_duckarray_equal(ds["cable"], cables)
6872

69-
# HasDuplicates should create a trace dimension
73+
# Both HasDuplicates and NonBinned should create a trace dimension
7074
expected = list(range(1, np.amax(receivers_per_cable) + 1))
7175
xrt.assert_duckarray_equal(ds["trace"], expected)
7276

7377
times_expected = list(range(0, num_samples, 1))
7478
xrt.assert_duckarray_equal(ds["time"], times_expected)
7579

76-
# HasDuplicates uses chunksize of 1 for trace dimension
80+
# Check trace chunk size based on grid override
7781
trace_chunks = ds["amplitude"].chunksizes.get("trace", None)
7882
if trace_chunks is not None:
79-
assert all(chunk == 1 for chunk in trace_chunks)
83+
if "NonBinned" in grid_override:
84+
# NonBinned uses specified chunksize for trace dimension
85+
expected_chunksize = grid_override.get("chunksize", 1)
86+
assert all(chunk == expected_chunksize for chunk in trace_chunks)
87+
else:
88+
# HasDuplicates uses chunksize of 1 for trace dimension
89+
assert all(chunk == 1 for chunk in trace_chunks)
8090

8191

8292
@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, None])

tests/unit/test_segy_grid_overrides.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def test_duplicates(self, mock_streamer_headers: dict[str, npt.NDArray]) -> None
103103
def test_non_binned(self, mock_streamer_headers: dict[str, npt.NDArray]) -> None:
104104
"""Test the NonBinned Grid Override command."""
105105
index_names = ("shot_point", "cable")
106-
grid_overrides = {"NonBinned": True, "chunksize": 4}
106+
grid_overrides = {"NonBinned": True, "chunksize": 4, "non_binned_dims": ["channel"]}
107107

108-
# Remove channel header
109-
streamer_headers = mock_streamer_headers[list(index_names)]
108+
# Keep channel header for non-binned processing
109+
streamer_headers = mock_streamer_headers
110110
chunksize = (4, 4, 8)
111111

112112
new_headers, new_names, new_chunks = run_override(
@@ -123,7 +123,9 @@ def test_non_binned(self, mock_streamer_headers: dict[str, npt.NDArray]) -> None
123123

124124
assert_array_equal(dims[0].coords, SHOTS)
125125
assert_array_equal(dims[1].coords, CABLES)
126-
assert_array_equal(dims[2].coords, RECEIVERS)
126+
# Trace coords are the unique channel values (1-20)
127+
expected_trace_coords = np.arange(1, 21, dtype="int32")
128+
assert_array_equal(dims[2].coords, expected_trace_coords)
127129

128130

129131
class TestStreamerGridOverrides:

0 commit comments

Comments
 (0)