Skip to content

Commit fad687e

Browse files
committed
fix override parameter bug and add tests
1 parent db30fa2 commit fad687e

File tree

2 files changed

+220
-33
lines changed

2 files changed

+220
-33
lines changed

src/mdio/segy/geometry.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,21 @@ class GridOverrideCommand(ABC):
110110
def required_keys(self) -> set:
111111
"""Get the set of required keys for the grid override command."""
112112

113+
@property
114+
@abstractmethod
115+
def required_parameters(self) -> set:
116+
"""Get the set of required parameters for the grid override command."""
117+
113118
@abstractmethod
114-
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
119+
def validate(
120+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
121+
) -> None:
115122
"""Validate if this transform should run on the type of data."""
116123

117124
@abstractmethod
118-
def transform(self, index_headers: npt.NDArray, **kwargs) -> npt.NDArray:
125+
def transform(
126+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
127+
) -> dict[str, npt.NDArray]:
119128
"""Perform the grid transform."""
120129

121130
@property
@@ -129,25 +138,42 @@ def check_required_keys(self, index_headers: npt.NDArray) -> None:
129138
if not self.required_keys.issubset(index_names):
130139
raise GridOverrideKeysError(self.name, self.required_keys)
131140

141+
def check_required_params(self, grid_overrides: dict[str, str | int]) -> None:
142+
"""Check if all required keys are present in the index headers."""
143+
if self.required_parameters is None:
144+
return
145+
146+
passed_parameters = set(grid_overrides.keys())
147+
148+
if not self.required_parameters.issubset(passed_parameters):
149+
missing_params = self.required_parameters - passed_parameters
150+
raise GridOverrideMissingParameterError(self.name, missing_params)
151+
132152

133153
class AutoChannelWrap(GridOverrideCommand):
134154
"""Automatically determine Streamer acquisition type."""
135155

136156
required_keys = {"shot", "cable", "channel"}
157+
required_parameters = None
137158

138-
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
159+
def validate(
160+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
161+
) -> None:
139162
"""Validate if this transform should run on the type of data."""
140-
self.check_required_keys(index_headers)
141-
142-
if "ChannelWrap" in kwargs:
163+
if "ChannelWrap" in grid_overrides:
143164
raise GridOverrideIncompatibleError(self.name, "ChannelWrap")
144165

145-
if "CalculateCable" in kwargs:
166+
if "CalculateCable" in grid_overrides:
146167
raise GridOverrideIncompatibleError(self.name, "CalculateCable")
147168

148-
def transform(self, index_headers: npt.NDArray, **kwargs):
169+
self.check_required_keys(index_headers)
170+
self.check_required_params(grid_overrides)
171+
172+
def transform(
173+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
174+
) -> dict[str, npt.NDArray]:
149175
"""Perform the grid transform."""
150-
self.validate(index_headers, **kwargs)
176+
self.validate(index_headers, grid_overrides)
151177

152178
result = analyze_streamer_headers(index_headers)
153179
unique_cables, cable_chan_min, cable_chan_max, geom_type = result
@@ -179,22 +205,25 @@ class ChannelWrap(GridOverrideCommand):
179205
"""Wrap channels to start from one at cable boundaries."""
180206

181207
required_keys = {"shot", "cable", "channel"}
208+
required_parameters = {"ChannelsPerCable"}
182209

183-
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
210+
def validate(
211+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
212+
) -> None:
184213
"""Validate if this transform should run on the type of data."""
185-
self.check_required_keys(index_headers)
186-
187-
if "ChannelsPerCable" not in kwargs:
188-
raise GridOverrideMissingParameterError(self.name, "ChannelsPerCable")
189-
190-
if "AutoCableChannel" in kwargs:
214+
if "AutoChannelWrap" in grid_overrides:
191215
raise GridOverrideIncompatibleError(self.name, "AutoCableChannel")
192216

193-
def transform(self, index_headers: npt.NDArray, **kwargs) -> npt.NDArray:
217+
self.check_required_keys(index_headers)
218+
self.check_required_params(grid_overrides)
219+
220+
def transform(
221+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
222+
) -> dict[str, npt.NDArray]:
194223
"""Perform the grid transform."""
195-
self.validate(index_headers, **kwargs)
224+
self.validate(index_headers, grid_overrides)
196225

197-
channels_per_cable = kwargs["ChannelsPerCable"]
226+
channels_per_cable = grid_overrides["ChannelsPerCable"]
198227
index_headers["channel"] = (
199228
index_headers["channel"] - 1
200229
) % channels_per_cable + 1
@@ -206,20 +235,25 @@ class CalculateCable(GridOverrideCommand):
206235
"""Calculate cable numbers from unwrapped channels."""
207236

208237
required_keys = {"shot", "cable", "channel"}
238+
required_parameters = {"ChannelsPerCable"}
209239

210-
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
240+
def validate(
241+
self, index_headers: npt.NDArray, grid_overrides: dict[str, bool | int]
242+
) -> None:
211243
"""Validate if this transform should run on the type of data."""
212-
self.check_required_keys(index_headers)
213-
214-
if "ChannelsPerCable" not in kwargs:
215-
raise GridOverrideMissingParameterError(self.name, "ChannelsPerCable")
216-
217-
if "AutoCableChannel" in kwargs:
244+
if "AutoChannelWrap" in grid_overrides:
218245
raise GridOverrideIncompatibleError(self.name, "AutoCableChannel")
219246

220-
def transform(self, index_headers, **kwargs):
247+
self.check_required_keys(index_headers)
248+
self.check_required_params(grid_overrides)
249+
250+
def transform(
251+
self, index_headers, grid_overrides: dict[str, bool | int]
252+
) -> dict[str, npt.NDArray]:
221253
"""Perform the grid transform."""
222-
channels_per_cable = kwargs["ChannelsPerCable"]
254+
self.validate(index_headers, grid_overrides)
255+
256+
channels_per_cable = grid_overrides["ChannelsPerCable"]
223257
index_headers["cable"] = (
224258
index_headers["channel"] - 1
225259
) // channels_per_cable + 1
@@ -237,24 +271,40 @@ class GridOverrider:
237271
"""
238272

239273
def __init__(self):
240-
"""Define allowed overrides here."""
274+
"""Define allowed overrides and parameters here."""
241275
self.commands = {
242276
"AutoChannelWrap": AutoChannelWrap(),
243277
"CalculateCable": CalculateCable(),
244278
"ChannelWrap": ChannelWrap(),
245279
}
246280

281+
self.parameters = self.get_allowed_parameters()
282+
283+
def get_allowed_parameters(self) -> set:
284+
"""Get list of allowed parameters from the allowed commands."""
285+
parameters = set()
286+
for command in self.commands.values():
287+
if command.required_parameters is None:
288+
continue
289+
290+
parameters.update(command.required_parameters)
291+
292+
return parameters
293+
247294
def run(
248295
self,
249296
index_headers: npt.NDArray,
250297
grid_overrides: dict[str, bool],
251298
) -> npt.NDArray:
252299
"""Run grid overrides and return result."""
253300
for override in grid_overrides:
254-
if override in self.commands:
255-
function = self.commands[override].transform
256-
index_headers = function(index_headers, grid_overrides=grid_overrides)
257-
else:
301+
if override in self.parameters:
302+
continue
303+
304+
if override not in self.commands:
258305
raise GridOverrideUnknownError(override)
259306

307+
function = self.commands[override].transform
308+
index_headers = function(index_headers, grid_overrides=grid_overrides)
309+
260310
return index_headers
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Check grid overrides."""
2+
3+
4+
from __future__ import annotations
5+
6+
import numpy.typing as npt
7+
import pytest
8+
from numpy import arange
9+
from numpy import column_stack
10+
from numpy import meshgrid
11+
from numpy import unique
12+
from numpy.testing import assert_array_equal
13+
14+
from mdio.core import Dimension
15+
from mdio.segy.exceptions import GridOverrideIncompatibleError
16+
from mdio.segy.exceptions import GridOverrideMissingParameterError
17+
from mdio.segy.exceptions import GridOverrideUnknownError
18+
from mdio.segy.geometry import GridOverrider
19+
20+
21+
SHOTS = arange(100, 104, dtype="int32")
22+
CABLES = arange(11, 15, dtype="int32")
23+
RECEIVERS = arange(1, 6, dtype="int32")
24+
25+
26+
@pytest.fixture
27+
def mock_streamer_headers() -> dict[str, npt.NDArray]:
28+
"""Generate dictionary of mocked streamer index headers."""
29+
grids = meshgrid(SHOTS, CABLES, RECEIVERS, indexing="ij")
30+
permutations = column_stack([grid.ravel() for grid in grids])
31+
32+
# Make channel from receiver ids
33+
for shot in SHOTS:
34+
shot_mask = permutations[:, 0] == shot
35+
permutations[shot_mask, -1] = arange(1, len(CABLES) * len(RECEIVERS) + 1)
36+
37+
result = dict(
38+
shot=permutations[:, 0],
39+
cable=permutations[:, 1],
40+
channel=permutations[:, 2],
41+
)
42+
43+
return result
44+
45+
46+
class TestStreamerGridOverrides:
47+
"""Check grid overrides for shot data with streamer acquisition."""
48+
49+
def test_channel_wrap(self, mock_streamer_headers: npt.NDArray) -> None:
50+
"""Test the ChannelWrap command."""
51+
grid_overrides = {"ChannelWrap": True, "ChannelsPerCable": len(RECEIVERS)}
52+
53+
overrider = GridOverrider()
54+
results = overrider.run(mock_streamer_headers, grid_overrides)
55+
56+
dims = []
57+
for index_name, index_coords in results.items():
58+
dim_unique = unique(index_coords)
59+
dims.append(Dimension(coords=dim_unique, name=index_name))
60+
61+
assert_array_equal(dims[0], SHOTS)
62+
assert_array_equal(dims[1], CABLES)
63+
assert_array_equal(dims[2], RECEIVERS)
64+
65+
def test_calculate_cable(self, mock_streamer_headers: npt.NDArray) -> None:
66+
"""Test the CalculateCable command."""
67+
grid_overrides = {
68+
"CalculateCable": True,
69+
"ChannelsPerCable": len(RECEIVERS),
70+
}
71+
72+
overrider = GridOverrider()
73+
results = overrider.run(mock_streamer_headers, grid_overrides)
74+
75+
dims = []
76+
for index_name, index_coords in results.items():
77+
dim_unique = unique(index_coords)
78+
dims.append(Dimension(coords=dim_unique, name=index_name))
79+
80+
# We need channels because unwrap isn't done here
81+
channels = unique(mock_streamer_headers["channel"])
82+
83+
# We reset the cables to start from 1.
84+
cables = arange(1, len(CABLES) + 1, dtype="uint32")
85+
86+
assert_array_equal(dims[0], SHOTS)
87+
assert_array_equal(dims[1], cables)
88+
assert_array_equal(dims[2], channels)
89+
90+
def test_wrap_and_calc_cable(self, mock_streamer_headers: npt.NDArray) -> None:
91+
"""Test the combined ChannelWrap and CalculateCable commands."""
92+
grid_overrides = {
93+
"CalculateCable": True,
94+
"ChannelWrap": True,
95+
"ChannelsPerCable": len(RECEIVERS),
96+
}
97+
98+
overrider = GridOverrider()
99+
results = overrider.run(mock_streamer_headers, grid_overrides)
100+
101+
dims = []
102+
for index_name, index_coords in results.items():
103+
dim_unique = unique(index_coords)
104+
dims.append(Dimension(coords=dim_unique, name=index_name))
105+
106+
# We reset the cables to start from 1.
107+
cables = arange(1, len(CABLES) + 1, dtype="uint32")
108+
109+
assert_array_equal(dims[0], SHOTS)
110+
assert_array_equal(dims[1], cables)
111+
assert_array_equal(dims[2], RECEIVERS)
112+
113+
def test_missing_param(self, mock_streamer_headers: npt.NDArray) -> None:
114+
"""Test missing parameters for the commands."""
115+
overrider = GridOverrider()
116+
with pytest.raises(GridOverrideMissingParameterError):
117+
overrider.run(mock_streamer_headers, {"ChannelWrap": True})
118+
119+
with pytest.raises(GridOverrideMissingParameterError):
120+
overrider.run(mock_streamer_headers, {"CalculateCable": True})
121+
122+
def test_incompatible_overrides(self, mock_streamer_headers: npt.NDArray) -> None:
123+
"""Test commands that can't be run together."""
124+
overrider = GridOverrider()
125+
with pytest.raises(GridOverrideIncompatibleError):
126+
grid_overrides = {"ChannelWrap": True, "AutoChannelWrap": True}
127+
overrider.run(mock_streamer_headers, grid_overrides)
128+
129+
with pytest.raises(GridOverrideIncompatibleError):
130+
grid_overrides = {"CalculateCable": True, "AutoChannelWrap": True}
131+
overrider.run(mock_streamer_headers, grid_overrides)
132+
133+
def test_unknown_override(self, mock_streamer_headers: npt.NDArray) -> None:
134+
"""Test exception if user provides a command that's not allowed."""
135+
overrider = GridOverrider()
136+
with pytest.raises(GridOverrideUnknownError):
137+
overrider.run(mock_streamer_headers, {"WrongCommand": True})

0 commit comments

Comments
 (0)