Skip to content

Commit 0ceccc1

Browse files
Eager memory allocation fix (TGSAI#609)
* Implement fixes to ensure lazy allocation of data arrays on serialization
* Avoid unnecessary copies of data in memory
* Linting
* Eliminate immediate overwrite of `data` bug
* Remove unused import
* Set appropriate fill value for lazy arrays
* Clean up header value handler
* Resolve data serialization issues
* Ensure all encodings are captured
* Simplify dataset coordinate population logic by removing unused imports and redundant variable handling
* Refactor `_workers.py` to streamline variable handling, replace manual Variable creation with direct assignment, and resolve redundant imports.
* Make better use of grid
* Fix type hint
* fix(regression): make dataset serialization less eager
* Update zarr
* Remove comment

---------

Co-authored-by: Altay Sansal <[email protected]>
1 parent a8b12f3 commit 0ceccc1

File tree

6 files changed

+268
-267
lines changed

6 files changed

+268
-267
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
"segy (>=0.4.2,<0.5.0)",
3030
"tqdm (>=4.67.1,<5.0.0)",
3131
"xarray>=2025.7.1",
32-
"zarr (>=3.1.1,<4.0.0)",
32+
"zarr (>=3.1.2,<4.0.0)",
3333
]
3434

3535
[project.optional-dependencies]

src/mdio/converters/segy.py

Lines changed: 21 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ def _build_and_check_grid(segy_dimensions: list[Dimension], segy_file: SegyFile,
163163

164164

165165
def _get_coordinates(
166-
segy_dimensions: list[Dimension],
166+
grid: Grid,
167167
segy_headers: SegyHeaderArray,
168168
mdio_template: AbstractDatasetTemplate,
169169
) -> tuple[list[Dimension], dict[str, SegyHeaderArray]]:
@@ -174,7 +174,7 @@ def _get_coordinates(
174174
The last dimension is always the vertical domain dimension
175175
176176
Args:
177-
segy_dimensions: List of of all SEG-Y dimensions.
177+
grid: Inferred MDIO grid for SEG-Y file.
178178
segy_headers: Headers read in from SEG-Y file.
179179
mdio_template: The MDIO template to use for the conversion.
180180
@@ -188,19 +188,15 @@ def _get_coordinates(
188188
- A dict of non-dimension coordinates (str: N-D arrays).
189189
"""
190190
dimensions_coords = []
191-
dim_names = [dim.name for dim in segy_dimensions]
192191
for dim_name in mdio_template.dimension_names:
193-
try:
194-
dim_index = dim_names.index(dim_name)
195-
except ValueError:
192+
if dim_name not in grid.dim_names:
196193
err = f"Dimension '{dim_name}' was not found in SEG-Y dimensions."
197-
raise ValueError(err) from err
198-
dimensions_coords.append(segy_dimensions[dim_index])
194+
raise ValueError(err)
195+
dimensions_coords.append(grid.select_dim(dim_name))
199196

200197
non_dim_coords: dict[str, SegyHeaderArray] = {}
201-
available_headers = segy_headers.dtype.names
202198
for coord_name in mdio_template.coordinate_names:
203-
if coord_name not in available_headers:
199+
if coord_name not in segy_headers.dtype.names:
204200
err = f"Coordinate '{coord_name}' not found in SEG-Y dimensions."
205201
raise ValueError(err)
206202
non_dim_coords[coord_name] = segy_headers[coord_name]
@@ -227,12 +223,14 @@ def populate_non_dim_coordinates(
227223
"""Populate the xarray dataset with coordinate variables."""
228224
not_null = grid.map[:] != UINT32_MAX
229225
for c_name, c_values in coordinates.items():
230-
dataset[c_name].values[not_null] = c_values
226+
c_tmp_array = dataset[c_name].values
227+
c_tmp_array[not_null] = c_values
228+
dataset[c_name][:] = c_tmp_array
231229
drop_vars_delayed.append(c_name)
232230
return dataset, drop_vars_delayed
233231

234232

235-
def _get_horizontal_coordinate_unit(segy_headers: list[Dimension]) -> LengthUnitEnum | None:
233+
def _get_horizontal_coordinate_unit(segy_headers: list[Dimension]) -> AllUnits | None:
236234
"""Get the coordinate unit from the SEG-Y headers."""
237235
name = TraceHeaderFieldsRev0.COORDINATE_UNIT.name.upper()
238236
unit_hdr = next((c for c in segy_headers if c.name.upper() == name), None)
@@ -347,15 +345,17 @@ def segy_to_mdio(
347345

348346
grid = _build_and_check_grid(segy_dimensions, segy_file, segy_headers)
349347

350-
dimensions, non_dim_coords = _get_coordinates(segy_dimensions, segy_headers, mdio_template)
351-
shape = [len(dim.coords) for dim in dimensions]
348+
dimensions, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
352349
# TODO(Altay): Turn this dtype into packed representation
353350
# https://github.com/TGSAI/mdio-python/issues/601
354351
headers = to_structured_type(segy_spec.trace.header.dtype)
355352

356353
horizontal_unit = _get_horizontal_coordinate_unit(segy_dimensions)
357354
mdio_ds: Dataset = mdio_template.build_dataset(
358-
name=mdio_template.name, sizes=shape, horizontal_coord_unit=horizontal_unit, headers=headers
355+
name=mdio_template.name,
356+
sizes=grid.shape,
357+
horizontal_coord_unit=horizontal_unit,
358+
headers=headers,
359359
)
360360

361361
_add_text_binary_headers(dataset=mdio_ds, segy_file=segy_file)
@@ -376,18 +376,12 @@ def segy_to_mdio(
376376
# IMPORTANT: Do not drop the "trace_mask" here, as it will be used later in
377377
# blocked_io.to_zarr() -> _workers.trace_worker()
378378

379-
# Write the xarray dataset to Zarr with as following:
380-
# Populated arrays:
381-
# - 1D dimensional coordinates
382-
# - ND non-dimensional coordinates
383-
# - ND trace_mask
384-
# Empty arrays (will be populated later in chunks):
385-
# - ND+1 traces
386-
# - ND headers (no _FillValue set due to the bug https://github.com/TGSAI/mdio-python/issues/582)
387-
# This will create the Zarr store with the correct structure
388-
# TODO(Dmitriy Repin): do chunked write for non-dimensional coordinates and trace_mask
389-
# https://github.com/TGSAI/mdio-python/issues/587
390-
xr_dataset.to_zarr(store=output_location.uri, mode="w", write_empty_chunks=False, zarr_format=2, compute=True)
379+
# This will create the Zarr store with the correct structure but with empty arrays
380+
xr_dataset.to_zarr(store=output_location.uri, mode="w", write_empty_chunks=False, zarr_format=2, compute=False)
381+
382+
# This will write the non-dimension coordinates and trace mask
383+
meta_ds = xr_dataset[drop_vars_delayed + ["trace_mask"]]
384+
meta_ds.to_zarr(store=output_location.uri, mode="r+", write_empty_chunks=False, zarr_format=2, compute=True)
391385

392386
# Now we can drop them to simplify chunked write of the data variable
393387
xr_dataset = xr_dataset.drop_vars(drop_vars_delayed)

src/mdio/schemas/v1/dataset_serializer.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
"""Convert MDIO v1 schema Dataset to Xarray DataSet and write it in Zarr."""
22

33
import numpy as np
4+
from dask import array as dask_array
45
from numcodecs import Blosc as nc_Blosc
56
from xarray import DataArray as xr_DataArray
67
from xarray import Dataset as xr_Dataset
7-
from zarr import zeros as zarr_zeros
88
from zarr.core.chunk_key_encodings import V2ChunkKeyEncoding
99

1010
from mdio.converters.type_converter import to_numpy_dtype
@@ -177,8 +177,8 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912
177177
mdio_ds: The source MDIO dataset to construct from.
178178
179179
Notes:
180-
- We can't use Dask (e.g., dask_array.zeros) because of the problems with
181-
structured type support. We will uze zarr.zeros instead
180+
- Using dask.array.zeros for lazy evaluation to prevent eager memory allocation
181+
while maintaining support for structured dtypes
182182
183183
Returns:
184184
The constructed dataset with proper MDIO structure and metadata.
@@ -195,9 +195,8 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912
195195
dtype = to_numpy_dtype(v.data_type)
196196
chunks = _get_zarr_chunks(v, all_named_dims=all_named_dims)
197197

198-
# Use zarr.zeros to create an empty array with the specified shape and dtype
199-
# NOTE: zarr_format=2 is essential, to_zarr() will fail if zarr_format=2 is used
200-
data = zarr_zeros(shape=shape, dtype=dtype, zarr_format=2)
198+
# Use dask.array.zeros to create a lazy array
199+
data = dask_array.full(shape=shape, dtype=dtype, chunks=chunks, fill_value=_get_fill_value(v.data_type))
201200
# Create a DataArray for the variable. We will set coords in the second pass
202201
dim_names = _get_dimension_names(v)
203202
data_array = xr_DataArray(data, dims=dim_names)

src/mdio/segy/_workers.py

Lines changed: 24 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
import numpy as np
1111
from segy import SegyFile
1212

13+
from mdio.schemas import ScalarType
14+
1315
if TYPE_CHECKING:
1416
from segy.arrays import HeaderArray
1517
from segy.config import SegySettings
@@ -19,7 +21,9 @@
1921

2022
from mdio.core.storage_location import StorageLocation
2123

24+
2225
from mdio.constants import UINT32_MAX
26+
from mdio.schemas.v1.dataset_serializer import _get_fill_value
2327
from mdio.schemas.v1.stats import CenteredBinHistogram
2428
from mdio.schemas.v1.stats import SummaryStatistics
2529

@@ -109,33 +113,37 @@ def trace_worker( # noqa: PLR0913
109113
live_trace_indexes = grid_map[not_null].tolist()
110114
traces = segy_file.trace[live_trace_indexes]
111115

116+
header_key = "headers"
117+
112118
# Get subset of the dataset that has not yet been saved
113119
# The headers might not be present in the dataset
114-
# TODO(Dmitriy Repin): Check, should we overwrite the 'dataset' instead to save the memory
115-
# https://github.com/TGSAI/mdio-python/issues/584
116-
if "headers" in dataset.data_vars:
117-
ds_to_write = dataset[[data_variable_name, "headers"]]
118-
ds_to_write = ds_to_write.reset_coords()
119-
120-
ds_to_write["headers"].data[not_null] = traces.header
121-
ds_to_write["headers"].data[~not_null] = 0
122-
else:
123-
ds_to_write = dataset[[data_variable_name]]
124-
ds_to_write = ds_to_write.reset_coords()
120+
worker_variables = [data_variable_name]
121+
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
122+
worker_variables.append(header_key)
123+
124+
ds_to_write = dataset[worker_variables]
125+
126+
if header_key in worker_variables:
127+
# Create temporary array for headers with the correct shape
128+
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
129+
tmp_headers = np.zeros_like(dataset[header_key])
130+
tmp_headers[not_null] = traces.header
131+
ds_to_write[header_key][:] = tmp_headers
125132

126-
ds_to_write[data_variable_name].data[not_null] = traces.sample
133+
data_variable = ds_to_write[data_variable_name]
134+
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
135+
tmp_samples = np.full_like(data_variable, fill_value=fill_value)
136+
tmp_samples[not_null] = traces.sample
137+
ds_to_write[data_variable_name][:] = tmp_samples
127138

128-
out_path = output_location.uri
129-
ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
139+
ds_to_write.to_zarr(output_location.uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
130140

131141
histogram = CenteredBinHistogram(bin_centers=[], counts=[])
132142
return SummaryStatistics(
133143
count=traces.sample.size,
134144
min=traces.sample.min(),
135145
max=traces.sample.max(),
136146
sum=traces.sample.sum(),
137-
# TODO(Altay): Look at how to do the sum squares statistic correctly
138-
# https://github.com/TGSAI/mdio-python/issues/581
139147
sum_squares=(traces.sample**2).sum(),
140148
histogram=histogram,
141149
)

src/mdio/segy/blocked_io.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,6 @@ def to_zarr( # noqa: PLR0913, PLR0915
9898
num_workers = min(num_chunks, num_cpus)
9999
context = mp.get_context("spawn")
100100
executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
101-
# return executor
102101

103102
segy_kw = {
104103
"url": segy_file.fs.unstrip_protocol(segy_file.url),

0 commit comments

Comments (0)