
Commit bdebc7e

Use native zarr for trace worker instead of xarray for memory efficiency. (TGSAI#710)
* Ingestion with zarr instead of xarray
* Optimizations
* Pre-commit
* Flag patch version change for debugging and testing
* Use sensible variable names
* revert version change: this will trigger CI/CD to make a new release, so it is better not done here and must be done from `main`
* remove unnecessary I/O calls by reading existing values

---------

Co-authored-by: Altay Sansal <[email protected]>
Co-authored-by: Altay Sansal <[email protected]>
1 parent 768368c commit bdebc7e
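At a high level, the commit stops handing each worker an xarray `Dataset` slice and instead has the worker open the Zarr store itself and assign numpy buffers directly into its target region. A rough sketch of that write pattern, with illustrative paths and variable names rather than the exact mdio internals:

```python
import numpy as np
import zarr

# Open the existing output store for in-place writes ("r+"); the path is a placeholder.
group = zarr.open_group("output.mdio", mode="r+")
data = group["amplitude"]  # hypothetical data variable name

# The region of the array this worker owns, expressed as slices.
region = (slice(0, 128), slice(0, 128), slice(0, 1024))

# Start from a buffer pre-filled with the array's fill value, scatter the
# live traces into it, then assign the whole region in one call.
shape = tuple(s.stop - s.start for s in region)
buffer = np.full(shape, data.fill_value, dtype=data.dtype)
# buffer[live_mask] = traces.sample  # live_mask would come from the grid map
data[region] = buffer
```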

File tree

2 files changed: +31 −55 lines changed

src/mdio/segy/_workers.py

Lines changed: 30 additions & 54 deletions
@@ -12,8 +12,7 @@
 from segy import SegyFile
 from segy.arrays import HeaderArray
 
-from mdio.api.io import to_mdio
-from mdio.builder.schemas.dtype import ScalarType
+from mdio.api.io import _normalize_storage_options
 from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
 from mdio.segy.scalar import _get_coordinate_scalar
 
@@ -22,15 +21,13 @@
 from segy.config import SegyHeaderOverrides
 from segy.schema import SegySpec
 from upath import UPath
-from xarray import Dataset as xr_Dataset
 from zarr import Array as zarr_Array
 
-from xarray import Variable
+from zarr import open_group as zarr_open_group
 from zarr.core.config import config as zarr_config
 
 from mdio.builder.schemas.v1.stats import CenteredBinHistogram
 from mdio.builder.schemas.v1.stats import SummaryStatistics
-from mdio.builder.xarray_builder import _get_fill_value
 from mdio.constants import fill_value_map
 
 if TYPE_CHECKING:
@@ -100,7 +97,6 @@ def trace_worker( # noqa: PLR0913
     data_variable_name: str,
     region: dict[str, slice],
     grid_map: zarr_Array,
-    dataset: xr_Dataset,
 ) -> SummaryStatistics | None:
     """Writes a subset of traces from a region of the dataset of Zarr file.
 
@@ -112,7 +108,6 @@ def trace_worker( # noqa: PLR0913
         data_variable_name: Name of the data variable to write.
         region: Region of the dataset to write to.
         grid_map: Zarr array mapping live traces to their positions in the dataset.
-        dataset: Xarray dataset containing the data to write.
 
     Returns:
        SummaryStatistics object containing statistics about the written traces.
@@ -135,16 +130,15 @@ def trace_worker( # noqa: PLR0913
 
     live_trace_indexes = local_grid_map[not_null].tolist()
 
+    # Open the zarr group to write directly
+    storage_options = _normalize_storage_options(output_path)
+    zarr_group = zarr_open_group(output_path.as_posix(), mode="r+", storage_options=storage_options)
+
     header_key = "headers"
     raw_header_key = "raw_headers"
 
-    # Get subset of the dataset that has not yet been saved
-    # The headers might not be present in the dataset
-    worker_variables = [data_variable_name]
-    if header_key in dataset.data_vars:  # Keeping the `if` here to allow for more worker configurations
-        worker_variables.append(header_key)
-    if raw_header_key in dataset.data_vars:
-        worker_variables.append(raw_header_key)
+    # Check which variables exist in the zarr store
+    available_arrays = list(zarr_group.array_keys())
 
     # traces = segy_file.trace[live_trace_indexes]
     # Raw headers are not intended to remain as a feature of the SEGY ingestion.
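For context, the public zarr API the new worker leans on here (opening an existing group read-write, forwarding fsspec storage options, and listing its arrays) can be exercised on its own roughly as below. The bucket path and options are placeholders, and `_normalize_storage_options` is an mdio-internal helper not reproduced here:

```python
import zarr

# Open an existing Zarr/MDIO store for read-write access. For remote stores
# (e.g. S3), storage_options are forwarded to fsspec; the values are placeholders.
group = zarr.open_group(
    "s3://my-bucket/survey.mdio",
    mode="r+",
    storage_options={"anon": False},
)

# Discover which variables the store actually contains before writing.
available_arrays = list(group.array_keys())
for key in ("headers", "raw_headers"):
    if key in available_arrays:
        print(key, group[key].shape, group[key].dtype)
```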
@@ -153,51 +147,33 @@
     # NOTE: The `raw_header_key` code block should be removed in full as it will become dead code.
     traces = SegyFileRawTraceWrapper(segy_file, live_trace_indexes)
 
-    ds_to_write = dataset[worker_variables]
+    # Compute slices once (headers exclude sample dimension)
+    header_region_slices = region_slices[:-1]  # Exclude sample dimension
+
+    full_shape = tuple(s.stop - s.start for s in region_slices)
+    header_shape = tuple(s.stop - s.start for s in header_region_slices)
 
-    if raw_header_key in worker_variables:
-        tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
+    # Write raw headers if they exist
+    # Headers only have spatial dimensions (no sample dimension)
+    if raw_header_key in available_arrays:
+        raw_header_array = zarr_group[raw_header_key]
+        tmp_raw_headers = np.full(header_shape, raw_header_array.fill_value)
         tmp_raw_headers[not_null] = traces.raw_header
+        raw_header_array[header_region_slices] = tmp_raw_headers
 
-        ds_to_write[raw_header_key] = Variable(
-            ds_to_write[raw_header_key].dims,
-            tmp_raw_headers,
-            attrs=ds_to_write[raw_header_key].attrs,
-            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
-        )
-
-    if header_key in worker_variables:
-        # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code
-        # https://github.com/TGSAI/mdio-python/issues/584
-        tmp_headers = np.zeros_like(dataset[header_key])
+    # Write headers if they exist
+    # Headers only have spatial dimensions (no sample dimension)
+    if header_key in available_arrays:
+        header_array = zarr_group[header_key]
+        tmp_headers = np.full(header_shape, header_array.fill_value)
         tmp_headers[not_null] = traces.header
-        # Create a new Variable object to avoid copying the temporary array
-        # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
-        # but Xarray appears to be copying memory instead of doing direct assignment.
-        # TODO(BrianMichell): #614 Look into this further.
-        # https://github.com/TGSAI/mdio-python/issues/584
-        ds_to_write[header_key] = Variable(
-            ds_to_write[header_key].dims,
-            tmp_headers,
-            attrs=ds_to_write[header_key].attrs,
-            encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
-        )
-
-    data_variable = ds_to_write[data_variable_name]
-    fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
-    tmp_samples = np.full_like(data_variable, fill_value=fill_value)
-    tmp_samples[not_null] = traces.sample
+        header_array[header_region_slices] = tmp_headers
 
-    # TODO(BrianMichell): #614 Look into this further.
-    # https://github.com/TGSAI/mdio-python/issues/584
-    ds_to_write[data_variable_name] = Variable(
-        ds_to_write[data_variable_name].dims,
-        tmp_samples,
-        attrs=ds_to_write[data_variable_name].attrs,
-        encoding=ds_to_write[data_variable_name].encoding,  # Not strictly necessary, but safer than not doing it.
-    )
-
-    to_mdio(ds_to_write, output_path=output_path, region=region, mode="r+")
+    # Write the data variable
+    data_array = zarr_group[data_variable_name]
+    tmp_samples = np.full(full_shape, data_array.fill_value)
+    tmp_samples[not_null] = traces.sample
+    data_array[region_slices] = tmp_samples
 
     nonzero_samples = np.ma.masked_values(traces.sample, 0, copy=False)
     histogram = CenteredBinHistogram(bin_centers=[], counts=[])
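The shape bookkeeping in the new worker is the only subtle part: trace headers live on the spatial grid only, so the sample dimension is dropped from their region before the buffers are built. A self-contained sketch of that logic, assuming (as the diff does) that `region_slices` puts the sample dimension last:

```python
import numpy as np

# Example region covering a 2 x 3 spatial patch with 4 samples per trace.
region_slices = (slice(0, 2), slice(10, 13), slice(0, 4))

# Headers have no sample dimension, so drop the last slice.
header_region_slices = region_slices[:-1]

full_shape = tuple(s.stop - s.start for s in region_slices)           # (2, 3, 4)
header_shape = tuple(s.stop - s.start for s in header_region_slices)  # (2, 3)

# Buffers start at the array's fill value; only live traces are overwritten.
not_null = np.array([[True, False, True], [False, True, True]])  # illustrative live-trace mask
samples = np.full(full_shape, np.nan)
samples[not_null] = 1.0  # stand-in for traces.sample
print(full_shape, header_shape, int(not_null.sum()))
```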

src/mdio/segy/blocked_io.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def to_zarr( # noqa: PLR0913, PLR0915
     futures = []
     common_args = (segy_file_kwargs, output_path, data_variable_name)
     for region in chunk_iter:
-        subset_args = (region, grid_map, dataset.isel(region))
+        subset_args = (region, grid_map)
         future = executor.submit(trace_worker, *common_args, *subset_args)
         futures.append(future)
 
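Since the dataset subset is no longer passed along, each submitted task now carries only the region and the grid map, and the worker opens the store itself. A simplified sketch of the submission loop, wrapped in a hypothetical `submit_regions` helper for illustration (the argument order follows the diff; the executor setup is assumed):

```python
from concurrent.futures import ProcessPoolExecutor


def submit_regions(executor: ProcessPoolExecutor, chunk_iter, trace_worker,
                   segy_file_kwargs, output_path, data_variable_name, grid_map):
    """Submit one trace_worker task per region; workers open the Zarr store themselves."""
    futures = []
    common_args = (segy_file_kwargs, output_path, data_variable_name)
    for region in chunk_iter:
        subset_args = (region, grid_map)
        futures.append(executor.submit(trace_worker, *common_args, *subset_args))
    return futures
```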
