Skip to content

Commit a19c6b9

Browse files
committed
Implement fixes to ensure lazy allocation of data arrays on serialization
1 parent a8b12f3 commit a19c6b9

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

src/mdio/schemas/v1/dataset_serializer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Convert MDIO v1 schema Dataset to Xarray DataSet and write it in Zarr."""
22

33
import numpy as np
4+
from dask import array as dask_array
5+
from zarr import zeros as zarr_zeros
46
from numcodecs import Blosc as nc_Blosc
57
from xarray import DataArray as xr_DataArray
68
from xarray import Dataset as xr_Dataset
7-
from zarr import zeros as zarr_zeros
89
from zarr.core.chunk_key_encodings import V2ChunkKeyEncoding
910

1011
from mdio.converters.type_converter import to_numpy_dtype
@@ -177,8 +178,8 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912
177178
mdio_ds: The source MDIO dataset to construct from.
178179
179180
Notes:
180-
- We can't use Dask (e.g., dask_array.zeros) because of the problems with
181-
structured type support. We will use zarr.zeros instead
181+
- Using dask.array.zeros for lazy evaluation to prevent eager memory allocation
182+
while maintaining support for structured dtypes
182183
183184
Returns:
184185
The constructed dataset with proper MDIO structure and metadata.
@@ -195,9 +196,14 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912
195196
dtype = to_numpy_dtype(v.data_type)
196197
chunks = _get_zarr_chunks(v, all_named_dims=all_named_dims)
197198

198-
# Use zarr.zeros to create an empty array with the specified shape and dtype
199-
# NOTE: zarr_format=2 is essential, to_zarr() will fail if zarr_format=2 is not used
200-
data = zarr_zeros(shape=shape, dtype=dtype, zarr_format=2)
199+
if hasattr(dtype, "fields"):
200+
data = zarr_zeros(shape=shape, dtype=dtype, zarr_format=2)
201+
else:
202+
data = dask_array.zeros(shape=shape, dtype=dtype, chunks=chunks)
203+
204+
# Use dask.array.zeros to create a lazy array with the specified shape and dtype
205+
# This prevents eager memory allocation while maintaining support for structured dtypes
206+
data = dask_array.zeros(shape=shape, dtype=dtype, chunks=chunks)
201207
# Create a DataArray for the variable. We will set coords in the second pass
202208
dim_names = _get_dimension_names(v)
203209
data_array = xr_DataArray(data, dims=dim_names)

src/mdio/segy/_workers.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,40 @@ def trace_worker( # noqa: PLR0913
117117
ds_to_write = dataset[[data_variable_name, "headers"]]
118118
ds_to_write = ds_to_write.reset_coords()
119119

120-
ds_to_write["headers"].data[not_null] = traces.header
121-
ds_to_write["headers"].data[~not_null] = 0
120+
try:
121+
# Create temporary array for headers with the correct shape
122+
tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write["headers"].dtype)
123+
tmp_headers[not_null] = traces.header
124+
ds_to_write["headers"].data[:] = tmp_headers
125+
except Exception as e:
126+
print(f"Error writing headers: {e}")
127+
print(f"not_null.shape: {not_null.shape}")
128+
print(f"traces.header.shape: {traces.header.shape}")
129+
print(f"ds_to_write['headers'].data.shape: {ds_to_write['headers'].data.shape}")
130+
raise e
131+
122132
else:
123133
ds_to_write = dataset[[data_variable_name]]
124134
ds_to_write = ds_to_write.reset_coords()
125135

126-
ds_to_write[data_variable_name].data[not_null] = traces.sample
136+
try:
137+
# Get the sample dimension size from the data variable itself
138+
sample_dim_size = ds_to_write[data_variable_name].shape[-1]
139+
tmp_samples = np.zeros(not_null.shape + (sample_dim_size,), dtype=ds_to_write[data_variable_name].dtype)
140+
141+
# Assign trace samples to the correct positions
142+
# We need to handle the fact that traces.sample is (num_traces, num_samples)
143+
# and we want to put it into positions where not_null is True
144+
tmp_samples[not_null] = traces.sample
145+
ds_to_write[data_variable_name].data[:] = tmp_samples
146+
except Exception as e:
147+
print(f"Error writing samples: {e}")
148+
print(f"not_null.shape: {not_null.shape}")
149+
print(f"traces.sample.shape: {traces.sample.shape}")
150+
print(f"ds_to_write[data_variable_name].data.shape: {ds_to_write[data_variable_name].data.shape}")
151+
print(f"not_null.sum(): {not_null.sum()}")
152+
print(f"len(traces.sample): {len(traces.sample)}")
153+
raise e
127154

128155
out_path = output_location.uri
129156
ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)

0 commit comments

Comments
 (0)