Skip to content

Commit 3dcbf25

Browse files
(Re)implement dynamic chunking for live mask and non-dimension coordinates (TGSAI#660)
* Re-implement on-the-fly chunking for non-dimension coordinates and live-mask
* refactor a little
* make chunks a little smaller
* ensure non-dim coordinates are compressed
* readjust sizes

---------

Co-authored-by: Altay Sansal <[email protected]>
1 parent bd541e1 commit 3dcbf25

File tree

7 files changed

+60
-16
lines changed

7 files changed

+60
-16
lines changed

src/mdio/builder/templates/abstract_dataset_template.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,17 @@ def _add_coordinates(self) -> None:
160160
)
161161

162162
# Add non-dimension coordinates
163-
# TODO(Dmitriy Repin): do chunked write for non-dimensional coordinates and trace_mask
164-
# https://github.com/TGSAI/mdio-python/issues/587
165-
# The chunk size used for trace mask will be different from the _var_chunk_shape
166163
for i in range(len(self._coord_names)):
167164
self._builder.add_coordinate(
168165
self._coord_names[i],
169166
dimensions=self._coord_dim_names,
170167
data_type=ScalarType.FLOAT64,
168+
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
171169
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
172170
)
173171

174172
def _add_trace_mask(self) -> None:
175173
"""Add trace mask variables."""
176-
# TODO(Dmitriy Repin): do chunked write for non-dimensional coordinates and trace_mask
177-
# https://github.com/TGSAI/mdio-python/issues/587
178-
# The chunk size used for trace mask will be different from the _var_chunk_shape
179174
self._builder.add_variable(
180175
name="trace_mask",
181176
dimensions=self._dim_names[:-1], # All dimensions except vertical (the last one)

src/mdio/builder/templates/seismic_2d_prestack_shot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any
44

5+
from mdio.builder.schemas import compressors
56
from mdio.builder.schemas.dtype import ScalarType
67
from mdio.builder.schemas.v1.variable import CoordinateMetadata
78
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
@@ -32,33 +33,39 @@ def _add_coordinates(self) -> None:
3233
self._builder.add_coordinate(name, dimensions=(name,), data_type=ScalarType.INT32)
3334

3435
# Add non-dimension coordinates
36+
compressor = compressors.Blosc(cname=compressors.BloscCname.zstd)
3537
coordinate_metadata = CoordinateMetadata(units_v1=self._horizontal_coord_unit)
3638
self._builder.add_coordinate(
3739
"gun",
3840
dimensions=("shot_point",),
3941
data_type=ScalarType.UINT8,
42+
compressor=compressor,
4043
)
4144
self._builder.add_coordinate(
4245
"source_coord_x",
4346
dimensions=("shot_point",),
4447
data_type=ScalarType.FLOAT64,
48+
compressor=compressor,
4549
metadata=coordinate_metadata,
4650
)
4751
self._builder.add_coordinate(
4852
"source_coord_y",
4953
dimensions=("shot_point",),
5054
data_type=ScalarType.FLOAT64,
55+
compressor=compressor,
5156
metadata=coordinate_metadata,
5257
)
5358
self._builder.add_coordinate(
5459
"group_coord_x",
5560
dimensions=("shot_point", "channel"),
5661
data_type=ScalarType.FLOAT64,
62+
compressor=compressor,
5763
metadata=coordinate_metadata,
5864
)
5965
self._builder.add_coordinate(
6066
"group_coord_y",
6167
dimensions=("shot_point", "channel"),
6268
data_type=ScalarType.FLOAT64,
69+
compressor=compressor,
6370
metadata=coordinate_metadata,
6471
)

src/mdio/builder/templates/seismic_3d_prestack_coca.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any
44

5+
from mdio.builder.schemas import compressors
56
from mdio.builder.schemas.dtype import ScalarType
67
from mdio.builder.schemas.v1.units import AngleUnitModel
78
from mdio.builder.schemas.v1.variable import CoordinateMetadata
@@ -59,15 +60,18 @@ def _add_coordinates(self) -> None:
5960
)
6061

6162
# Add non-dimension coordinates
63+
compressor = compressors.Blosc(cname=compressors.BloscCname.zstd)
6264
self._builder.add_coordinate(
6365
"cdp_x",
6466
dimensions=("inline", "crossline"),
6567
data_type=ScalarType.FLOAT64,
68+
compressor=compressor,
6669
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
6770
)
6871
self._builder.add_coordinate(
6972
"cdp_y",
7073
dimensions=("inline", "crossline"),
7174
data_type=ScalarType.FLOAT64,
75+
compressor=compressor,
7276
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
7377
)

src/mdio/builder/templates/seismic_3d_prestack_shot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any
44

5+
from mdio.builder.schemas import compressors
56
from mdio.builder.schemas.dtype import ScalarType
67
from mdio.builder.schemas.v1.variable import CoordinateMetadata
78
from mdio.builder.templates.abstract_dataset_template import AbstractDatasetTemplate
@@ -32,32 +33,38 @@ def _add_coordinates(self) -> None:
3233
self._builder.add_coordinate(name, dimensions=(name,), data_type=ScalarType.INT32)
3334

3435
# Add non-dimension coordinates
36+
compressor = compressors.Blosc(cname=compressors.BloscCname.zstd)
3537
self._builder.add_coordinate(
3638
"gun",
3739
dimensions=("shot_point",),
3840
data_type=ScalarType.UINT8,
41+
compressor=compressor,
3942
)
4043
self._builder.add_coordinate(
4144
"source_coord_x",
4245
dimensions=("shot_point",),
4346
data_type=ScalarType.FLOAT64,
47+
compressor=compressor,
4448
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
4549
)
4650
self._builder.add_coordinate(
4751
"source_coord_y",
4852
dimensions=("shot_point",),
4953
data_type=ScalarType.FLOAT64,
54+
compressor=compressor,
5055
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
5156
)
5257
self._builder.add_coordinate(
5358
"group_coord_x",
5459
dimensions=("shot_point", "cable", "channel"),
5560
data_type=ScalarType.FLOAT64,
61+
compressor=compressor,
5662
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
5763
)
5864
self._builder.add_coordinate(
5965
"group_coord_y",
6066
dimensions=("shot_point", "cable", "channel"),
6167
data_type=ScalarType.FLOAT64,
68+
compressor=compressor,
6269
metadata=CoordinateMetadata(units_v1=self._horizontal_coord_unit),
6370
)

src/mdio/converters/segy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
from mdio.converters.exceptions import GridTraceSparsityError
3333
from mdio.converters.type_converter import to_structured_type
3434
from mdio.core.grid import Grid
35+
from mdio.core.utils_write import MAX_COORDINATES_BYTES
36+
from mdio.core.utils_write import MAX_SIZE_LIVE_MASK
37+
from mdio.core.utils_write import get_constrained_chunksize
3538
from mdio.segy import blocked_io
3639
from mdio.segy.utilities import get_grid_plan
3740

@@ -429,6 +432,32 @@ def enhanced_add_variables() -> None:
429432
return mdio_template
430433

431434

435+
def _chunk_variable(ds: Dataset, target_variable_name: str) -> None:
    """Determine and set the chunk grid for one Variable of the Dataset.

    The chunk shape is computed by ``get_constrained_chunksize`` from the variable's
    full shape, its data type, and a target size in bytes: ``MAX_SIZE_LIVE_MASK`` for
    ``"bool"`` variables and ``MAX_COORDINATES_BYTES`` for every other type. The
    resulting ``RegularChunkGrid`` is stored on the variable's metadata, which is
    created first when the variable has none. The dataset is modified in place.

    Args:
        ds: Dataset whose ``variables`` hold the variable to chunk.
        target_variable_name: Name of the variable to chunk.

    Raises:
        ValueError: If ``ds`` has no variable named ``target_variable_name``.
    """
    # Look the variable up by name; fail loudly instead of indexing with ``None`` later.
    variable = next((var for var in ds.variables if var.name == target_variable_name), None)
    if variable is None:
        msg = f"Variable '{target_variable_name}' not found in dataset; cannot set its chunking."
        raise ValueError(msg)

    # Boolean variables (the live mask) get their own size budget; all others share one.
    var_type = variable.data_type
    target_size = MAX_SIZE_LIVE_MASK if var_type == "bool" else MAX_COORDINATES_BYTES
    full_shape = tuple(dim.size for dim in variable.dimensions)

    # Create the chunk grid metadata
    chunk_shape = get_constrained_chunksize(full_shape, var_type, target_size)
    chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=chunk_shape))

    # Create variable metadata if it doesn't exist
    if variable.metadata is None:
        variable.metadata = VariableMetadata()

    variable.metadata.chunk_grid = chunk_grid
459+
460+
432461
def segy_to_mdio( # noqa PLR0913
433462
segy_spec: SegySpec,
434463
mdio_template: AbstractDatasetTemplate,
@@ -487,6 +516,11 @@ def segy_to_mdio( # noqa PLR0913
487516

488517
_add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides)
489518

519+
# Dynamically chunk the variables based on their type
520+
_chunk_variable(ds=mdio_ds, target_variable_name="trace_mask") # trace_mask is a Variable and not a Coordinate
521+
for coord in mdio_template.coordinate_names:
522+
_chunk_variable(ds=mdio_ds, target_variable_name=coord)
523+
490524
xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)
491525

492526
xr_dataset, drop_vars_delayed = _populate_coordinates(

src/mdio/core/utils_write.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from numpy.typing import DTypeLike
1010

1111

12-
MAX_SIZE_LIVE_MASK = 512 * 1024**2
13-
14-
JsonSerializable = str | int | float | bool | None | dict[str, "JsonSerializable"] | list["JsonSerializable"]
12+
MAX_SIZE_LIVE_MASK = 256 * 1024**2
13+
MAX_COORDINATES_BYTES = 32 * 1024**2
1514

1615

1716
def get_constrained_chunksize(

tests/unit/test_auto_chunking.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class TestAutoChunkLiveMask:
4848
[
4949
((100,), (100,)), # small 1d
5050
((100, 100), (100, 100)), # small 2d
51-
((50000, 50000), (25000, 25000)), # large 2d
51+
((50000, 50000), (16667, 16667)), # large 2d
5252
((1500, 1500, 1500), (750, 750, 750)), # large 3d
53-
((1000, 1000, 100, 36), (334, 334, 100, 36)), # large 4d
53+
((1000, 1000, 100, 36), (250, 250, 100, 36)), # large 4d
5454
],
5555
)
5656
def test_auto_chunk_live_mask(
@@ -65,16 +65,14 @@ def test_auto_chunk_live_mask(
6565
@pytest.mark.parametrize(
6666
"shape",
6767
[
68-
# Below are >500MiB. Smaller ones tested above
68+
# Below are >250MiB. Smaller ones tested above
6969
(32768, 32768),
7070
(46341, 46341),
7171
(86341, 96341),
7272
(55000, 97500),
7373
(100000, 100000),
74-
(1024, 1024, 1024),
75-
(215, 215, 215, 215),
7674
(512, 216, 512, 400),
77-
(74, 74, 74, 74, 74),
75+
(64, 128, 64, 32, 64),
7876
(512, 17, 43, 200, 50),
7977
],
8078
)
@@ -84,6 +82,6 @@ def test_auto_chunk_live_mask_nbytes(self, shape: tuple[int, ...]) -> None:
8482
result = get_live_mask_chunksize(shape)
8583
chunk_elements = np.prod(result)
8684

87-
# We want them to be 500MB +/- 25%
85+
# We want them to be 256 MiB +/- 25%
8886
assert chunk_elements > MAX_SIZE_LIVE_MASK * 0.75
8987
assert chunk_elements < MAX_SIZE_LIVE_MASK * 1.25

0 commit comments

Comments
 (0)