Skip to content

Commit 3e5763a

Browse files
committed
Add test to cover prestack template.
1 parent 515c3a7 commit 3e5763a

File tree

4 files changed

+256
-38
lines changed

4 files changed

+256
-38
lines changed

src/mdio/builder/template_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mdio.builder.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate
1212
from mdio.builder.templates.seismic_3d_prestack_coca import Seismic3DPreStackCocaTemplate
1313
from mdio.builder.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate
14+
from mdio.builder.templates.seismic_prestack import SeismicPreStackTemplate
1415

1516

1617
class TemplateRegistry:
@@ -85,6 +86,7 @@ def _register_default_templates(self) -> None:
8586
self.register(Seismic3DPreStackCocaTemplate("depth"))
8687

8788
# Field (shot) data
89+
self.register(SeismicPreStackTemplate("time"))
8890
self.register(Seismic2DPreStackShotTemplate("time"))
8991
self.register(Seismic3DPreStackShotTemplate("time"))
9092

src/mdio/builder/templates/seismic_prestack.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""SeismicPreStackTemplate MDIO v1 dataset templates."""
22

3-
from mdio.schemas.dtype import ScalarType
4-
from mdio.schemas.metadata import UserAttributes
5-
from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate
3+
from typing import Any
4+
5+
from mdio.builder.schemas.dtype import ScalarType
6+
from mdio.builder.schemas.v1.variable import CoordinateMetadata
7+
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
8+
from mdio.builder.templates.types import SeismicDataDomain
69

710

811
class SeismicPreStackTemplate(AbstractDatasetTemplate):
@@ -14,8 +17,8 @@ class SeismicPreStackTemplate(AbstractDatasetTemplate):
1417
data_domain: The domain of the dataset.
1518
"""
1619

17-
def __init__(self, domain: str = "time"):
18-
super().__init__(domain=domain)
20+
def __init__(self, data_domain: SeismicDataDomain):
21+
super().__init__(data_domain=data_domain)
1922

2023
self._coord_dim_names = [
2124
"shot_line",
@@ -24,54 +27,59 @@ def __init__(self, domain: str = "time"):
2427
"cable",
2528
"channel",
2629
] # Custom coordinate dimension names for shot gathers
27-
self._dim_names = [*self._coord_dim_names, self._trace_domain]
28-
self._coord_names = ["source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"]
30+
self._dim_names = [*self._coord_dim_names, self._data_domain]
31+
self._coord_names = [
32+
"energy_source_point_number",
33+
"source_coord_x",
34+
"source_coord_y",
35+
"group_coord_x",
36+
"group_coord_y",
37+
]
2938
self._var_chunk_shape = [1, 1, 16, 1, 32, -1]
3039

3140
@property
3241
def _name(self) -> str:
33-
return f"PreStackGathers3D{self._trace_domain.capitalize()}"
42+
return f"PreStackGathers3D{self._data_domain.capitalize()}"
3443

35-
def _load_dataset_attributes(self) -> UserAttributes:
36-
return UserAttributes(
37-
attributes={
38-
"surveyDimensionality": "3D",
39-
"ensembleType": "shot_point",
40-
"processingStage": "pre-stack",
41-
}
42-
)
44+
def _load_dataset_attributes(self) -> dict[str, Any]:
45+
return {
46+
"surveyDimensionality": "3D",
47+
"ensembleType": "shot_point",
48+
"processingStage": "pre-stack",
49+
}
4350

4451
def _add_coordinates(self) -> None:
4552
# Add dimension coordinates
4653
for name in self._dim_names:
47-
self._builder.add_coordinate(
48-
name,
49-
dimensions=[name],
50-
data_type=ScalarType.INT32,
51-
metadata_info=None,
52-
)
54+
self._builder.add_coordinate(name, dimensions=(name,), data_type=ScalarType.INT32)
5355

56+
# Add non-dimension coordinates
57+
self._builder.add_coordinate(
58+
"energy_source_point_number",
59+
dimensions=("shot_line", "gun", "shot_point"),
60+
data_type=ScalarType.INT32,
61+
)
5462
self._builder.add_coordinate(
5563
"source_coord_x",
56-
dimensions=["shot_line", "gun", "shot_point", "cable", "channel"],
64+
dimensions=("shot_line", "gun", "shot_point"),
5765
data_type=ScalarType.FLOAT64,
58-
metadata_info=[self._horizontal_coord_unit],
66+
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
5967
)
6068
self._builder.add_coordinate(
6169
"source_coord_y",
62-
dimensions=["shot_line", "gun", "shot_point", "cable", "channel"],
70+
dimensions=("shot_line", "gun", "shot_point"),
6371
data_type=ScalarType.FLOAT64,
64-
metadata_info=[self._horizontal_coord_unit],
72+
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
6573
)
6674
self._builder.add_coordinate(
6775
"group_coord_x",
68-
dimensions=["shot_line", "gun", "shot_point", "cable", "channel"],
76+
dimensions=("shot_line", "gun", "shot_point", "cable", "channel"),
6977
data_type=ScalarType.FLOAT64,
70-
metadata_info=[self._horizontal_coord_unit],
78+
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
7179
)
7280
self._builder.add_coordinate(
7381
"group_coord_y",
74-
dimensions=["shot_line", "gun", "shot_point", "cable", "channel"],
82+
dimensions=("shot_line", "gun", "shot_point", "cable", "channel"),
7583
data_type=ScalarType.FLOAT64,
76-
metadata_info=[self._horizontal_coord_unit],
84+
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
7785
)
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""Unit tests for SeismicPreStackTemplate."""
2+
3+
import pytest
4+
from tests.unit.v1.helpers import validate_variable
5+
6+
from mdio.builder.schemas.chunk_grid import RegularChunkGrid
7+
from mdio.builder.schemas.compressors import Blosc
8+
from mdio.builder.schemas.compressors import BloscCname
9+
from mdio.builder.schemas.dtype import ScalarType
10+
from mdio.builder.schemas.dtype import StructuredType
11+
from mdio.builder.schemas.v1.dataset import Dataset
12+
from mdio.builder.schemas.v1.units import LengthUnitEnum
13+
from mdio.builder.schemas.v1.units import LengthUnitModel
14+
from mdio.builder.templates.seismic_prestack import SeismicPreStackTemplate
15+
16+
UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
17+
18+
19+
def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None:
20+
"""Validate the coordinate, headers, trace_mask variables in the dataset."""
21+
# Verify variables
22+
# 6 dim coords + 5 non-dim coords + 1 data + 1 trace mask + 1 headers = 14 variables
23+
assert len(dataset.variables) == 14
24+
25+
# Verify trace headers
26+
validate_variable(
27+
dataset,
28+
name="headers",
29+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
30+
coords=["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
31+
dtype=headers,
32+
)
33+
34+
validate_variable(
35+
dataset,
36+
name="trace_mask",
37+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
38+
coords=["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
39+
dtype=ScalarType.BOOL,
40+
)
41+
42+
# Verify dimension coordinate variables
43+
shot_line = validate_variable(
44+
dataset,
45+
name="shot_line",
46+
dims=[("shot_line", 1)],
47+
coords=["shot_line"],
48+
dtype=ScalarType.INT32,
49+
)
50+
assert shot_line.metadata is None
51+
52+
gun = validate_variable(
53+
dataset,
54+
name="gun",
55+
dims=[("gun", 3)],
56+
coords=["gun"],
57+
dtype=ScalarType.INT32,
58+
)
59+
assert gun.metadata is None
60+
61+
shot_point = validate_variable(
62+
dataset,
63+
name="shot_point",
64+
dims=[("shot_point", 256)],
65+
coords=["shot_point"],
66+
dtype=ScalarType.INT32,
67+
)
68+
assert shot_point.metadata is None
69+
70+
cable = validate_variable(
71+
dataset,
72+
name="cable",
73+
dims=[("cable", 512)],
74+
coords=["cable"],
75+
dtype=ScalarType.INT32,
76+
)
77+
assert cable.metadata is None
78+
79+
channel = validate_variable(
80+
dataset,
81+
name="channel",
82+
dims=[("channel", 24)],
83+
coords=["channel"],
84+
dtype=ScalarType.INT32,
85+
)
86+
assert channel.metadata is None
87+
88+
domain_var = validate_variable(
89+
dataset,
90+
name=domain,
91+
dims=[(domain, 2048)],
92+
coords=[domain],
93+
dtype=ScalarType.INT32,
94+
)
95+
assert domain_var.metadata is None
96+
97+
# Verify non-dimension coordinate variables
98+
validate_variable(
99+
dataset,
100+
name="energy_source_point_number",
101+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)],
102+
coords=["energy_source_point_number"],
103+
dtype=ScalarType.INT32,
104+
)
105+
106+
source_coord_x = validate_variable(
107+
dataset,
108+
name="source_coord_x",
109+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)],
110+
coords=["source_coord_x"],
111+
dtype=ScalarType.FLOAT64,
112+
)
113+
assert source_coord_x.metadata.units_v1.length == LengthUnitEnum.METER
114+
115+
source_coord_y = validate_variable(
116+
dataset,
117+
name="source_coord_y",
118+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256)],
119+
coords=["source_coord_y"],
120+
dtype=ScalarType.FLOAT64,
121+
)
122+
assert source_coord_y.metadata.units_v1.length == LengthUnitEnum.METER
123+
124+
group_coord_x = validate_variable(
125+
dataset,
126+
name="group_coord_x",
127+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
128+
coords=["group_coord_x"],
129+
dtype=ScalarType.FLOAT64,
130+
)
131+
assert group_coord_x.metadata.units_v1.length == LengthUnitEnum.METER
132+
133+
group_coord_y = validate_variable(
134+
dataset,
135+
name="group_coord_y",
136+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24)],
137+
coords=["group_coord_y"],
138+
dtype=ScalarType.FLOAT64,
139+
)
140+
assert group_coord_y.metadata.units_v1.length == LengthUnitEnum.METER
141+
142+
143+
class TestSeismic3DPreStackShotTemplate:
144+
"""Unit tests for SeismicPreStackTemplate."""
145+
146+
def test_configuration(self) -> None:
147+
"""Unit tests for SeismicPreStackTemplate in time domain."""
148+
t = SeismicPreStackTemplate(data_domain="time")
149+
150+
# Template attributes for prestack shot
151+
assert t._data_domain == "time"
152+
assert t._coord_dim_names == ["shot_line", "gun", "shot_point", "cable", "channel"]
153+
assert t._dim_names == ["shot_line", "gun", "shot_point", "cable", "channel", "time"]
154+
assert t._coord_names == ["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"]
155+
assert t._var_chunk_shape == [1, 1, 16, 1, 32, -1]
156+
157+
# Variables instantiated when build_dataset() is called
158+
assert t._builder is None
159+
assert t._dim_sizes == ()
160+
assert t._horizontal_coord_unit is None
161+
162+
# Verify prestack shot attributes
163+
attrs = t._load_dataset_attributes()
164+
assert attrs == {"surveyDimensionality": "3D", "ensembleType": "shot_point", "processingStage": "pre-stack"}
165+
166+
assert t.name == "PreStackGathers3DTime"
167+
168+
def test_build_dataset(self, structured_headers: StructuredType) -> None:
169+
"""Unit tests for SeismicPreStackTemplate build in time domain."""
170+
t = SeismicPreStackTemplate(data_domain="time")
171+
172+
assert t.name == "PreStackGathers3DTime"
173+
dataset = t.build_dataset(
174+
"North Sea 3D Shot Time",
175+
sizes=(1, 3, 256, 512, 24, 2048),
176+
horizontal_coord_unit=UNITS_METER,
177+
header_dtype=structured_headers,
178+
)
179+
180+
assert dataset.metadata.name == "North Sea 3D Shot Time"
181+
assert dataset.metadata.attributes["surveyDimensionality"] == "3D"
182+
assert dataset.metadata.attributes["ensembleType"] == "shot_point"
183+
assert dataset.metadata.attributes["processingStage"] == "pre-stack"
184+
185+
_validate_coordinates_headers_trace_mask(dataset, structured_headers, "time")
186+
187+
# Verify seismic variable (prestack shot time data)
188+
seismic = validate_variable(
189+
dataset,
190+
name="amplitude",
191+
dims=[("shot_line", 1), ("gun", 3), ("shot_point", 256), ("cable", 512), ("channel", 24), ("time", 2048)],
192+
coords=["energy_source_point_number", "source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y"],
193+
dtype=ScalarType.FLOAT32,
194+
)
195+
assert isinstance(seismic.compressor, Blosc)
196+
assert seismic.compressor.cname == BloscCname.zstd
197+
assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid)
198+
assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 16, 1, 32, -1)
199+
assert seismic.metadata.stats_v1 is None
200+
201+
202+
@pytest.mark.parametrize("data_domain", ["Time", "TiME"])
203+
def test_domain_case_handling(data_domain: str) -> None:
204+
"""Test that domain parameter handles different cases correctly."""
205+
template = SeismicPreStackTemplate(data_domain=data_domain)
206+
assert template._data_domain == data_domain.lower()
207+
assert template.name.endswith(data_domain.capitalize())

0 commit comments

Comments
 (0)