
Commit faeb616

Export functionality for MDIO v1 ingested files (TGSAI#611)
* Export part 1
* Enable header value validation
* Revert the test names back
* Remove Endianness, new_chunks API args, and traceDomain
* PR review
* lint
* create/use new api location and lint
* allow configuring opener chunks
* clarify xarray open parameters
* fix regression of not-opening with native dask re-chunking
* fix regression of not-opening with native dask re-chunking
* make export rechunker work with named dimension sizes and chunks
* make StorageLocation available at library level and update mdio to segy example
* pre-open with zarr backend and simplify dataset slicing after lazy loading
* better opener docs
* more explicit xarray selection
* rename trace variable name to default variable name
* remove the guard for setting storage options to empty dictionary. new zarr is ok with None.
* update lockfile
* fix broken tests and inconsistent type hints
* clean up comments
* clarify binary header scaling
* make test names clearer
* fix broken unit tests due to storage_options handling

---------

Co-authored-by: Altay Sansal <[email protected]>
1 parent: fcffad8

22 files changed: +651 -582 lines changed

src/mdio/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,7 @@
 from mdio.api import MDIOReader
 from mdio.api import MDIOWriter
 from mdio.api.convenience import copy_mdio
+from mdio.api.opener import open_dataset
 from mdio.converters import mdio_to_segy
 from mdio.converters import numpy_to_mdio
 from mdio.converters import segy_to_mdio
@@ -14,11 +15,13 @@
 from mdio.core.factory import create_empty
 from mdio.core.factory import create_empty_like
 from mdio.core.grid import Grid
+from mdio.core.storage_location import StorageLocation
 
 __all__ = [
     "MDIOReader",
     "MDIOWriter",
     "copy_mdio",
+    "open_dataset",
     "mdio_to_segy",
     "numpy_to_mdio",
     "segy_to_mdio",
@@ -28,6 +31,7 @@
     "create_empty",
     "create_empty_like",
     "Grid",
+    "StorageLocation",
 ]
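With this change, the opener and storage-location helpers become part of the public package surface. A minimal sketch of the new imports (the dataset path is a hypothetical placeholder):

```python
# Minimal sketch of the newly exported top-level API.
# The .mdio path below is a hypothetical placeholder.
from mdio import StorageLocation, open_dataset

location = StorageLocation("prefix/file.mdio")  # local path or cloud URI
dataset = open_dataset(location)  # xarray.Dataset backed by the Zarr store
```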

src/mdio/api/convenience.py

Lines changed: 0 additions & 3 deletions
@@ -51,9 +51,6 @@ def copy_mdio(  # noqa: PLR0913
         storage_options_output: Storage options for output MDIO.
 
     """
-    storage_options_input = storage_options_input or {}
-    storage_options_output = storage_options_output or {}
-
     create_empty_like(
         source_path,
         target_path,

src/mdio/api/opener.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""Utils for reading MDIO dataset."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import xarray as xr
+
+if TYPE_CHECKING:
+    from xarray.core.types import T_Chunks
+
+    from mdio.core.storage_location import StorageLocation
+
+
+def open_dataset(storage_location: StorageLocation, chunks: T_Chunks = None) -> xr.Dataset:
+    """Open a Zarr dataset from the specified storage location.
+
+    Args:
+        storage_location: StorageLocation for the dataset.
+        chunks: If provided, loads data into dask arrays with new chunking.
+            - ``chunks="auto"`` will use dask ``auto`` chunking
+            - ``chunks=None`` skips using dask, which is generally faster for small arrays.
+            - ``chunks=-1`` loads the data with dask using a single chunk for all arrays.
+            - ``chunks={}`` loads the data with dask using the engine's preferred chunk size (on disk).
+            - ``chunks={dim: chunk, ...}`` loads the data with dask using the specified chunk size for each dimension.
+
+        See dask chunking for more details.
+
+    Returns:
+        An Xarray dataset opened from the storage location.
+    """
+    # NOTE: If mask_and_scale is not set,
+    # Xarray will convert int to float and replace _FillValue with NaN
+    # Fixed in Zarr v3, so we can fix this later.
+    return xr.open_dataset(storage_location.uri, engine="zarr", chunks=chunks, mask_and_scale=False)
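Since the docstring enumerates the `chunks` modes, a hedged usage sketch may help; the store path and the `inline` dimension name are assumptions for illustration:

```python
# Sketch of the documented `chunks` modes for open_dataset.
# The path and the "inline" dimension name are illustrative assumptions.
from mdio import StorageLocation, open_dataset

location = StorageLocation("prefix/file.mdio")

ds_eager = open_dataset(location)                # chunks=None: plain arrays, no dask
ds_auto = open_dataset(location, chunks="auto")  # dask "auto" chunking
ds_one = open_dataset(location, chunks=-1)       # a single dask chunk per array
ds_disk = open_dataset(location, chunks={})      # engine's preferred (on-disk) chunks
ds_dim = open_dataset(location, chunks={"inline": 128})  # per-dimension chunk size
```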

src/mdio/converters/mdio.py

Lines changed: 51 additions & 69 deletions
@@ -5,12 +5,13 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING
 
 import numpy as np
 from psutil import cpu_count
 from tqdm.dask import TqdmCallback
 
-from mdio import MDIOReader
+from mdio.api.opener import open_dataset
 from mdio.segy.blocked_io import to_segy
 from mdio.segy.creation import concat_files
 from mdio.segy.creation import mdio_spec_to_segy
@@ -21,18 +22,19 @@
 except ImportError:
     distributed = None
 
+if TYPE_CHECKING:
+    from segy.schema import SegySpec
+
+    from mdio.core.storage_location import StorageLocation
 
 default_cpus = cpu_count(logical=True)
 NUM_CPUS = int(os.getenv("MDIO__EXPORT__CPU_COUNT", default_cpus))
 
 
-def mdio_to_segy(  # noqa: PLR0912, PLR0913
-    mdio_path_or_buffer: str,
-    output_segy_path: str,
-    endian: str = "big",
-    access_pattern: str = "012",
-    storage_options: dict = None,
-    new_chunks: tuple[int, ...] = None,
+def mdio_to_segy(  # noqa: PLR0912, PLR0913, PLR0915
+    segy_spec: SegySpec,
+    input_location: StorageLocation,
+    output_location: StorageLocation,
     selection_mask: np.ndarray = None,
     client: distributed.Client = None,
 ) -> None:
@@ -47,13 +49,9 @@ def mdio_to_segy(  # noqa: PLR0912, PLR0913, PLR0915
    A `selection_mask` can be provided (same shape as spatial grid) to export a subset.
 
    Args:
-        mdio_path_or_buffer: Input path where the MDIO is located.
-        output_segy_path: Path to the output SEG-Y file.
-        endian: Endianness of the input SEG-Y. Rev.2 allows little endian. Default is 'big'.
-        access_pattern: This specificies the chunk access pattern. Underlying zarr.Array must
-            exist. Examples: '012', '01'
-        storage_options: Storage options for the cloud storage backend. Default: None (anonymous)
-        new_chunks: Set manual chunksize. For development purposes only.
+        segy_spec: The SEG-Y specification to use for the conversion.
+        input_location: Store or URL (and cloud options) for MDIO file.
+        output_location: Path to the output SEG-Y file.
         selection_mask: Array that lists the subset of traces
         client: Dask client. If `None` we will use local threaded scheduler. If `auto` is used we
             will create multiple processes (with 8 threads each).
@@ -64,86 +62,70 @@ def mdio_to_segy(  # noqa: PLR0912, PLR0913, PLR0915
 
    Examples:
        To export an existing local MDIO file to SEG-Y we use the code snippet below. This will
-        export the full MDIO (without padding) to SEG-Y format using IBM floats and big-endian
-        byte order.
+        export the full MDIO (without padding) to SEG-Y format.
 
-        >>> from mdio import mdio_to_segy
-        >>>
+        >>> from mdio import mdio_to_segy, StorageLocation
         >>>
-        >>> mdio_to_segy(
-        ...     mdio_path_or_buffer="prefix2/file.mdio",
-        ...     output_segy_path="prefix/file.segy",
-        ... )
-
-        If we want to export this as an IEEE big-endian, using a selection mask, we would run:
-
-        >>> mdio_to_segy(
-        ...     mdio_path_or_buffer="prefix2/file.mdio",
-        ...     output_segy_path="prefix/file.segy",
-        ...     selection_mask=boolean_mask,
-        ... )
-
+        >>> input_location = StorageLocation("prefix2/file.mdio")
+        >>> output_location = StorageLocation("prefix/file.segy")
+        >>> mdio_to_segy(input_location, output_location)
    """
-    backend = "dask"
-
-    output_segy_path = Path(output_segy_path)
+    output_segy_path = Path(output_location.uri)
 
-    mdio = MDIOReader(
-        mdio_path_or_buffer=mdio_path_or_buffer,
-        access_pattern=access_pattern,
-        storage_options=storage_options,
-    )
+    # First we open with vanilla zarr backend and then get some info
+    # We will re-open with `new_chunks` and Dask later in mdio_spec_to_segy
+    dataset = open_dataset(input_location)
 
-    if new_chunks is None:
-        new_chunks = segy_export_rechunker(mdio.chunks, mdio.shape, mdio._traces.dtype)
+    default_variable_name = dataset.attrs["attributes"]["default_variable_name"]
+    amplitude = dataset[default_variable_name]
+    chunks = amplitude.encoding["preferred_chunks"]
+    sizes = amplitude.sizes
+    dtype = amplitude.dtype
+    new_chunks = segy_export_rechunker(chunks, sizes, dtype)
 
-    creation_args = [
-        mdio_path_or_buffer,
-        output_segy_path,
-        access_pattern,
-        endian,
-        storage_options,
-        new_chunks,
-        backend,
-    ]
+    creation_args = [segy_spec, input_location, output_location, new_chunks]
 
     if client is not None:
         if distributed is not None:
             # This is in case we work with big data
             feature = client.submit(mdio_spec_to_segy, *creation_args)
-            mdio, segy_factory = feature.result()
+            dataset, segy_factory = feature.result()
         else:
             msg = "Distributed client was provided, but `distributed` is not installed"
             raise ImportError(msg)
     else:
-        mdio, segy_factory = mdio_spec_to_segy(*creation_args)
+        dataset, segy_factory = mdio_spec_to_segy(*creation_args)
 
-    live_mask = mdio.live_mask.compute()
+    trace_mask = dataset["trace_mask"].compute()
 
     if selection_mask is not None:
-        live_mask = live_mask & selection_mask
+        if trace_mask.shape != selection_mask.shape:
+            msg = "Selection mask and trace mask shapes do not match."
+            raise ValueError(msg)
+        selection_mask = trace_mask.copy(data=selection_mask)  # make into DataArray
+        trace_mask = trace_mask & selection_mask
 
     # This handles the case if we are skipping a whole block.
-    if live_mask.sum() == 0:
+    if trace_mask.sum() == 0:
         msg = "No traces will be written out. Live mask is empty."
         raise ValueError(msg)
 
     # Find rough dim limits, so we don't unnecessarily hit disk / cloud store.
     # Typically, gets triggered when there is a selection mask
-    dim_slices = ()
-    live_nonzeros = live_mask.nonzero()
-    for dim_nonzeros in live_nonzeros:
-        start = np.min(dim_nonzeros)
-        stop = np.max(dim_nonzeros) + 1
-        dim_slices += (slice(start, stop),)
+    dim_slices = {}
+    dim_live_indices = np.nonzero(trace_mask.values)
+    for dim_name, dim_live in zip(trace_mask.dims, dim_live_indices, strict=True):
+        start = dim_live.min().item()
+        stop = dim_live.max().item() + 1
+        dim_slices[dim_name] = slice(start, stop)
 
-    # Lazily pull the data with limits now, and limit mask so its the same shape.
-    live_mask, headers, samples = mdio[dim_slices]
-    live_mask = live_mask.rechunk(headers.chunks)
+    # Lazily pull the data with limits now.
+    # All the variables, metadata, etc. is all sliced to the same range.
+    dataset = dataset.isel(dim_slices)
 
     if selection_mask is not None:
         selection_mask = selection_mask[dim_slices]
-        live_mask = live_mask & selection_mask
+        dataset["trace_mask"] = dataset["trace_mask"] & selection_mask
 
     # tmp file root
     out_dir = output_segy_path.parent
@@ -152,9 +134,9 @@ def mdio_to_segy(  # noqa: PLR0912, PLR0913, PLR0915
     with tmp_dir:
         with TqdmCallback(desc="Unwrapping MDIO Blocks"):
             block_records = to_segy(
-                samples=samples,
-                headers=headers,
-                live_mask=live_mask,
+                samples=dataset[default_variable_name].data,
+                headers=dataset["headers"].data,
+                live_mask=dataset["trace_mask"].data,
                 segy_factory=segy_factory,
                 file_root=tmp_dir.name,
             )
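Putting the new signature together, an export with a selection mask might look like the sketch below. The `get_segy_standard` helper from the `segy` package and the grid shape are assumptions; only the two-location call is shown in the docstring:

```python
# Hedged sketch of the reworked export entry point with a selection mask.
# The spec helper and the grid shape are illustrative assumptions.
import numpy as np
from segy.standards import get_segy_standard

from mdio import StorageLocation, mdio_to_segy

segy_spec = get_segy_standard(1.0)  # assumed SEG-Y Rev 1 spec factory
input_location = StorageLocation("prefix2/file.mdio")
output_location = StorageLocation("prefix/file.segy")

# Boolean mask over the spatial grid; must match trace_mask's shape.
selection_mask = np.zeros((345, 188), dtype=bool)  # hypothetical grid shape
selection_mask[100:200, 50:120] = True

mdio_to_segy(segy_spec, input_location, output_location, selection_mask=selection_mask)
```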

src/mdio/converters/segy.py

Lines changed: 2 additions & 2 deletions
@@ -387,13 +387,13 @@ def segy_to_mdio(
     xr_dataset = xr_dataset.drop_vars(drop_vars_delayed)
 
     # Write the headers and traces in chunks using grid_map to indicate dead traces
-    data_variable_name = mdio_template.trace_variable_name
+    default_variable_name = mdio_template.default_variable_name
     # This is an memory-expensive and time-consuming read-write operation
     # performed in chunks to save the memory
     blocked_io.to_zarr(
         segy_file=segy_file,
         output_location=output_location,
         grid_map=grid.map,
         dataset=xr_dataset,
-        data_variable_name=data_variable_name,
+        data_variable_name=default_variable_name,
     )

src/mdio/core/factory.py

Lines changed: 0 additions & 5 deletions
@@ -121,8 +121,6 @@ def create_empty(
    """
    zarr.config.set({"default_zarr_format": 2, "write_empty_chunks": False})
 
-    storage_options = storage_options or {}
-
    url = process_url(url=config.path, disk_cache=False)
    root_group = open_group(url, mode="w", storage_options=storage_options)
    root_group = create_zarr_hierarchy(root_group, overwrite)
@@ -206,9 +204,6 @@ def create_empty_like(
        storage_options_input: Options for storage backend of the source dataset.
        storage_options_output: Options for storage backend of the destination dataset.
    """
-    storage_options_input = storage_options_input or {}
-    storage_options_output = storage_options_output or {}
-
    source_root = zarr.open_consolidated(
        source_path,
        mode="r",

src/mdio/schemas/v1/templates/abstract_dataset_template.py

Lines changed: 7 additions & 5 deletions
@@ -84,7 +84,9 @@ def build_dataset(
        self._dim_sizes = sizes
        self._horizontal_coord_unit = horizontal_coord_unit
 
-        self._builder = MDIODatasetBuilder(name=name, attributes=self._load_dataset_attributes())
+        attr = self._load_dataset_attributes() or UserAttributes(attributes={})
+        attr.attributes["default_variable_name"] = self._default_variable_name
+        self._builder = MDIODatasetBuilder(name=name, attributes=attr)
        self._add_dimensions()
        self._add_coordinates()
        self._add_variables()
@@ -99,9 +101,9 @@ def name(self) -> str:
        return self._name
 
    @property
-    def trace_variable_name(self) -> str:
+    def default_variable_name(self) -> str:
        """Returns the name of the trace variable."""
-        return self._trace_variable_name
+        return self._default_variable_name
 
    @property
    def trace_domain(self) -> str:
@@ -130,7 +132,7 @@ def _name(self) -> str:
        """
 
    @property
-    def _trace_variable_name(self) -> str:
+    def _default_variable_name(self) -> str:
        """Get the name of the data variable.
 
        A virtual method that can be overwritten by subclasses to return a
@@ -226,7 +228,7 @@ def _add_variables(self) -> None:
        Uses the class field 'builder' to add variables to the dataset.
        """
        self._builder.add_variable(
-            name=self._trace_variable_name,
+            name=self.default_variable_name,
            dimensions=self._dim_names,
            data_type=ScalarType.FLOAT32,
            compressor=compressors.Blosc(algorithm=compressors.BloscAlgorithm.ZSTD),
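To illustrate the rename on the template side, a subclass could override the virtual property as in this hypothetical sketch (class names assumed from the module path; other abstract members a real template must implement are omitted):

```python
# Hypothetical subclass showing the renamed virtual property.
# AbstractDatasetTemplate is assumed from the module path; the remaining
# abstract members a concrete template must implement are omitted here.
from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate


class VelocityTemplate(AbstractDatasetTemplate):
    @property
    def _default_variable_name(self) -> str:
        """Store trace data under a domain-specific variable name."""
        return "velocity"
```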
