Skip to content

Commit 12dfee9

Browse files
Make templates more clear by removing spatial dim names, and also allow changing full chunk shape. (TGSAI#722)
* refactor: replace `abstract_dataset_template` with `base`, update dimensions handling
* feat: add chunk_shape property with setter to manage variable chunk shapes
* refactor: rename `chunk_shape` to `full_chunk_shape` for clarity and consistency
* refactor: remove `_spatial_dim_names`
* refactor: update tests to use `full_chunk_shape`, replacing `_var_chunk_shape`
* refactor: update `chunk_grid` construction to use `full_chunk_shape`
* refactor: update required fields extraction to exclude data domain dimension
* refactor: replace hardcoded dimension slicing with `spatial_dimension_names` property for consistency and clarity
* refactor: update seismic template tests to focus on `full_chunk_shape` handling and validation
* refactor: reorder `dimension_names` for easy diff
* Apply suggestion from @BrianMichell

Co-authored-by: Brian Michell <[email protected]>

* refactor: update `subset` type from list to tuple for consistency in SEG-Y parsers
* refactor: rename variable `tpl` to `template`

---------

Co-authored-by: Altay Sansal <[email protected]>
Co-authored-by: Brian Michell <[email protected]>
1 parent e47f743 commit 12dfee9

24 files changed

+71
-109
lines changed

docs/template_registry.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ If you have a custom template class, register an instance so others can fetch it
6767
```python
6868
from typing import Any
6969
from mdio.builder.template_registry import register_template
70-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
70+
from mdio.builder.templates.base import AbstractDatasetTemplate
7171
from mdio.builder.templates.types import SeismicDataDomain
7272

73+
7374
class MyTemplate(AbstractDatasetTemplate):
7475
def __init__(self, domain: SeismicDataDomain = "time"):
7576
super().__init__(domain)
@@ -82,6 +83,7 @@ class MyTemplate(AbstractDatasetTemplate):
8283
def _load_dataset_attributes(self) -> dict[str, Any]:
8384
return {"surveyType": "2D", "gatherType": "custom"}
8485

86+
8587
# Make it available globally
8688
registered_name = register_template(MyTemplate("time"))
8789
print(registered_name) # "MyTemplateTime"

src/mdio/builder/template_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import threading
2020
from typing import TYPE_CHECKING
2121

22-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
22+
from mdio.builder.templates.base import AbstractDatasetTemplate
2323
from mdio.builder.templates.seismic_2d_poststack import Seismic2DPostStackTemplate
2424
from mdio.builder.templates.seismic_2d_prestack_cdp import Seismic2DPreStackCDPTemplate
2525
from mdio.builder.templates.seismic_2d_prestack_shot import Seismic2DPreStackShotTemplate
@@ -29,7 +29,7 @@
2929
from mdio.builder.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate
3030

3131
if TYPE_CHECKING:
32-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
32+
from mdio.builder.templates.base import AbstractDatasetTemplate
3333

3434

3535
__all__ = [

src/mdio/builder/templates/abstract_dataset_template.py renamed to src/mdio/builder/templates/base.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def __init__(self, data_domain: SeismicDataDomain) -> None:
3737
msg = "domain must be 'depth' or 'time'"
3838
raise ValueError(msg)
3939

40-
self._spatial_dim_names: tuple[str, ...] = ()
4140
self._dim_names: tuple[str, ...] = ()
4241
self._physical_coord_names: tuple[str, ...] = ()
4342
self._logical_coord_names: tuple[str, ...] = ()
@@ -99,8 +98,8 @@ def trace_domain(self) -> str:
9998

10099
@property
101100
def spatial_dimension_names(self) -> tuple[str, ...]:
102-
"""Returns the names of only the spatial dimensions."""
103-
return copy.deepcopy(self._spatial_dim_names)
101+
"""Returns the names of the dimensions excluding the last axis."""
102+
return copy.deepcopy(self._dim_names[:-1])
104103

105104
@property
106105
def dimension_names(self) -> tuple[str, ...]:
@@ -123,10 +122,18 @@ def coordinate_names(self) -> tuple[str, ...]:
123122
return copy.deepcopy(self._physical_coord_names + self._logical_coord_names)
124123

125124
@property
126-
def full_chunk_size(self) -> tuple[int, ...]:
127-
"""Returns the chunk size for the variables."""
125+
def full_chunk_shape(self) -> tuple[int, ...]:
126+
"""Returns the chunk shape for the variables."""
128127
return copy.deepcopy(self._var_chunk_shape)
129128

129+
@full_chunk_shape.setter
130+
def full_chunk_shape(self, shape: tuple[int, ...]) -> None:
131+
"""Sets the chunk shape for the variables."""
132+
if len(shape) != len(self._dim_sizes):
133+
msg = f"Chunk shape {shape} does not match dimension sizes {self._dim_sizes}"
134+
raise ValueError(msg)
135+
self._var_chunk_shape = shape
136+
130137
@property
131138
@abstractmethod
132139
def _name(self) -> str:
@@ -192,7 +199,7 @@ def _add_coordinates(self) -> None:
192199
for name in self.coordinate_names:
193200
self._builder.add_coordinate(
194201
name=name,
195-
dimensions=self._spatial_dim_names,
202+
dimensions=self.spatial_dimension_names,
196203
data_type=ScalarType.FLOAT64,
197204
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
198205
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)),
@@ -202,15 +209,15 @@ def _add_trace_mask(self) -> None:
202209
"""Add trace mask variables."""
203210
self._builder.add_variable(
204211
name="trace_mask",
205-
dimensions=self._spatial_dim_names,
212+
dimensions=self.spatial_dimension_names,
206213
data_type=ScalarType.BOOL,
207214
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd), # also default in zarr3
208215
coordinates=self.coordinate_names,
209216
)
210217

211218
def _add_trace_headers(self, header_dtype: StructuredType) -> None:
212219
"""Add trace mask variables."""
213-
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self._var_chunk_shape[:-1]))
220+
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self.full_chunk_shape[:-1]))
214221
self._builder.add_variable(
215222
name="headers",
216223
dimensions=self.spatial_dimension_names,
@@ -226,7 +233,7 @@ def _add_variables(self) -> None:
226233
A virtual method that can be overwritten by subclasses to add custom variables.
227234
Uses the class field 'builder' to add variables to the dataset.
228235
"""
229-
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self._var_chunk_shape))
236+
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=self.full_chunk_shape))
230237
unit = self.get_unit_by_key(self._default_variable_name)
231238
self._builder.add_variable(
232239
name=self.default_variable_name,

src/mdio/builder/templates/seismic_2d_poststack.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
5+
from mdio.builder.templates.base import AbstractDatasetTemplate
66
from mdio.builder.templates.types import SeismicDataDomain
77

88

@@ -12,8 +12,7 @@ class Seismic2DPostStackTemplate(AbstractDatasetTemplate):
1212
def __init__(self, data_domain: SeismicDataDomain):
1313
super().__init__(data_domain=data_domain)
1414

15-
self._spatial_dim_names = ("cdp",)
16-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
15+
self._dim_names = ("cdp", self._data_domain)
1716
self._physical_coord_names = ("cdp_x", "cdp_y")
1817
self._var_chunk_shape = (1024, 1024)
1918

src/mdio/builder/templates/seismic_2d_prestack_cdp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
5+
from mdio.builder.templates.base import AbstractDatasetTemplate
66
from mdio.builder.templates.types import CdpGatherDomain
77
from mdio.builder.templates.types import SeismicDataDomain
88

@@ -18,8 +18,7 @@ def __init__(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomai
1818
msg = "gather_type must be 'offset' or 'angle'"
1919
raise ValueError(msg)
2020

21-
self._spatial_dim_names = ("cdp", self._gather_domain)
22-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
21+
self._dim_names = ("cdp", self._gather_domain, self._data_domain)
2322
self._physical_coord_names = ("cdp_x", "cdp_y")
2423
self._var_chunk_shape = (16, 64, 1024)
2524

src/mdio/builder/templates/seismic_2d_prestack_shot.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mdio.builder.schemas import compressors
66
from mdio.builder.schemas.dtype import ScalarType
77
from mdio.builder.schemas.v1.variable import CoordinateMetadata
8-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
8+
from mdio.builder.templates.base import AbstractDatasetTemplate
99
from mdio.builder.templates.types import SeismicDataDomain
1010

1111

@@ -15,8 +15,7 @@ class Seismic2DPreStackShotTemplate(AbstractDatasetTemplate):
1515
def __init__(self, data_domain: SeismicDataDomain):
1616
super().__init__(data_domain=data_domain)
1717

18-
self._spatial_dim_names = ("shot_point", "channel")
19-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
18+
self._dim_names = ("shot_point", "channel", self._data_domain)
2019
self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
2120
self._logical_coord_names = ("gun",)
2221
self._var_chunk_shape = (16, 32, 2048)

src/mdio/builder/templates/seismic_3d_poststack.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
5+
from mdio.builder.templates.base import AbstractDatasetTemplate
66
from mdio.builder.templates.types import SeismicDataDomain
77

88

@@ -12,8 +12,7 @@ class Seismic3DPostStackTemplate(AbstractDatasetTemplate):
1212
def __init__(self, data_domain: SeismicDataDomain):
1313
super().__init__(data_domain=data_domain)
1414

15-
self._spatial_dim_names = ("inline", "crossline")
16-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
15+
self._dim_names = ("inline", "crossline", self._data_domain)
1716
self._physical_coord_names = ("cdp_x", "cdp_y")
1817
self._var_chunk_shape = (128, 128, 128)
1918

src/mdio/builder/templates/seismic_3d_prestack_cdp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
5+
from mdio.builder.templates.base import AbstractDatasetTemplate
66
from mdio.builder.templates.types import CdpGatherDomain
77
from mdio.builder.templates.types import SeismicDataDomain
88

@@ -18,8 +18,7 @@ def __init__(self, data_domain: SeismicDataDomain, gather_domain: CdpGatherDomai
1818
msg = "gather_type must be 'offset' or 'angle'"
1919
raise ValueError(msg)
2020

21-
self._spatial_dim_names = ("inline", "crossline", self._gather_domain)
22-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
21+
self._dim_names = ("inline", "crossline", self._gather_domain, self._data_domain)
2322
self._physical_coord_names = ("cdp_x", "cdp_y")
2423
self._var_chunk_shape = (8, 8, 32, 512)
2524

src/mdio/builder/templates/seismic_3d_prestack_coca.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mdio.builder.schemas import compressors
66
from mdio.builder.schemas.dtype import ScalarType
77
from mdio.builder.schemas.v1.variable import CoordinateMetadata
8-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
8+
from mdio.builder.templates.base import AbstractDatasetTemplate
99
from mdio.builder.templates.types import SeismicDataDomain
1010

1111

@@ -15,8 +15,7 @@ class Seismic3DPreStackCocaTemplate(AbstractDatasetTemplate):
1515
def __init__(self, data_domain: SeismicDataDomain):
1616
super().__init__(data_domain=data_domain)
1717

18-
self._spatial_dim_names = ("inline", "crossline", "offset", "azimuth")
19-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
18+
self._dim_names = ("inline", "crossline", "offset", "azimuth", self._data_domain)
2019
self._physical_coord_names = ("cdp_x", "cdp_y")
2120
self._var_chunk_shape = (8, 8, 32, 1, 1024)
2221

src/mdio/builder/templates/seismic_3d_prestack_shot.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mdio.builder.schemas import compressors
66
from mdio.builder.schemas.dtype import ScalarType
77
from mdio.builder.schemas.v1.variable import CoordinateMetadata
8-
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
8+
from mdio.builder.templates.base import AbstractDatasetTemplate
99
from mdio.builder.templates.types import SeismicDataDomain
1010

1111

@@ -15,8 +15,7 @@ class Seismic3DPreStackShotTemplate(AbstractDatasetTemplate):
1515
def __init__(self, data_domain: SeismicDataDomain):
1616
super().__init__(data_domain=data_domain)
1717

18-
self._spatial_dim_names = ("shot_point", "cable", "channel")
19-
self._dim_names = (*self._spatial_dim_names, self._data_domain)
18+
self._dim_names = ("shot_point", "cable", "channel", self._data_domain)
2019
self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
2120
self._logical_coord_names = ("gun",)
2221
self._var_chunk_shape = (8, 1, 128, 2048)

0 commit comments

Comments
 (0)