diff --git a/src/mdio/commands/segy.py b/src/mdio/commands/segy.py index 71e1c4d4..0bebda1d 100644 --- a/src/mdio/commands/segy.py +++ b/src/mdio/commands/segy.py @@ -13,6 +13,11 @@ from click_params import JSON from click_params import IntListParamType from click_params import StringListParamType +from segy.schema import HeaderField +from segy.standards import get_segy_standard + +from mdio.core.storage_location import StorageLocation +from mdio.schemas.v1.templates.template_registry import TemplateRegistry SEGY_HELP = """ MDIO and SEG-Y conversion utilities. Below is general information about the SEG-Y format and MDIO @@ -318,34 +323,29 @@ def segy_import( # noqa: PLR0913 # Lazy import to reduce CLI startup time from mdio import segy_to_mdio # noqa: PLC0415 + _ = (chunk_size, lossless, compression_tolerance, grid_overrides) + + segy_spec = get_segy_standard(1.0) + index_names = header_names or [f"dim_{i}" for i in range(len(header_locations))] + index_types = header_types or ["int32"] * len(header_locations) + index_fields = [ + HeaderField(name=name, byte=byte, format=format_) + for name, byte, format_ in zip(index_names, header_locations, index_types, strict=True) + ] + segy_spec = segy_spec.customize(trace_header_fields=index_fields) + segy_to_mdio( - segy_path=segy_path, - mdio_path_or_buffer=mdio_path, - index_bytes=header_locations, - index_types=header_types, - index_names=header_names, - chunksize=chunk_size, - lossless=lossless, - compression_tolerance=compression_tolerance, - storage_options_input=storage_options_input, - storage_options_output=storage_options_output, + segy_spec=segy_spec, + mdio_template=TemplateRegistry().get("PostStack3DTime"), + input_location=StorageLocation(segy_path, storage_options_input), + output_location=StorageLocation(mdio_path, storage_options_output), overwrite=overwrite, - grid_overrides=grid_overrides, ) @cli.command(name="export") @argument("mdio-file", type=STRING) @argument("segy-path", type=Path(exists=False)) -@option( - "-access", - "--access-pattern", - required=False, - default="012", - help="Access pattern of the file", - type=STRING, - show_default=True, -) @option( "-storage", "--storage-options", @@ -366,7 +366,6 @@ def segy_import( # noqa: PLR0913 def segy_export( mdio_file: str, segy_path: str, - access_pattern: str, storage_options: dict[str, Any], endian: str, ) -> None: @@ -391,9 +390,7 @@ def segy_export( from mdio import mdio_to_segy # noqa: PLC0415 mdio_to_segy( - mdio_path_or_buffer=mdio_file, - output_segy_path=segy_path, - access_pattern=access_pattern, - storage_options=storage_options, + input_location=StorageLocation(mdio_file, storage_options), + output_location=StorageLocation(str(segy_path)), endian=endian, ) diff --git a/src/mdio/converters/mdio.py b/src/mdio/converters/mdio.py index 9726cd6b..e373704d 100644 --- a/src/mdio/converters/mdio.py +++ b/src/mdio/converters/mdio.py @@ -5,17 +5,22 @@ import os from pathlib import Path from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING import numpy as np +import xarray as xr from psutil import cpu_count from tqdm.dask import TqdmCallback -from mdio import MDIOReader from mdio.segy.blocked_io import to_segy from mdio.segy.creation import concat_files +from mdio.segy.creation import get_required_segy_fields from mdio.segy.creation import mdio_spec_to_segy from mdio.segy.utilities import segy_export_rechunker +if TYPE_CHECKING: + from mdio.core.storage_location import StorageLocation + try: import distributed except ImportError: @@ -26,15 +31,15 @@ 
NUM_CPUS = int(os.getenv("MDIO__EXPORT__CPU_COUNT", default_cpus)) -def mdio_to_segy( # noqa: PLR0912, PLR0913 - mdio_path_or_buffer: str, - output_segy_path: str, +def mdio_to_segy( # noqa: PLR0912, PLR0913, PLR0915 + input_location: StorageLocation, + output_location: StorageLocation, + *, endian: str = "big", - access_pattern: str = "012", - storage_options: dict = None, new_chunks: tuple[int, ...] = None, selection_mask: np.ndarray = None, client: distributed.Client = None, + overwrite: bool = False, ) -> None: """Convert MDIO file to SEG-Y format. @@ -47,20 +52,19 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913 A `selection_mask` can be provided (same shape as spatial grid) to export a subset. Args: - mdio_path_or_buffer: Input path where the MDIO is located. - output_segy_path: Path to the output SEG-Y file. - endian: Endianness of the input SEG-Y. Rev.2 allows little endian. Default is 'big'. - access_pattern: This specificies the chunk access pattern. Underlying zarr.Array must - exist. Examples: '012', '01' - storage_options: Storage options for the cloud storage backend. Default: None (anonymous) + input_location: Location of the input MDIO file. + output_location: Location of the output SEG-Y file. + endian: Endianness of the output SEG-Y. Rev.2 allows little endian. Default is "big". new_chunks: Set manual chunksize. For development purposes only. - selection_mask: Array that lists the subset of traces + selection_mask: Array that lists the subset of traces. client: Dask client. If `None` we will use local threaded scheduler. If `auto` is used we will create multiple processes (with 8 threads each). + overwrite: Whether to overwrite the SEG-Y file if it already exists. Raises: - ImportError: if distributed package isn't installed but requested. - ValueError: if cut mask is empty, i.e. no traces will be written. + FileExistsError: If the output location already exists and `overwrite` is False. + ImportError: If distributed package isn't installed but requested. + ValueError: If cut mask is empty, i.e. no traces will be written. Examples: To export an existing local MDIO file to SEG-Y we use the code snippet below. This will @@ -69,40 +73,41 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913 >>> from mdio import mdio_to_segy >>> - >>> >>> mdio_to_segy( - ... mdio_path_or_buffer="prefix2/file.mdio", - ... output_segy_path="prefix/file.segy", + ... input_location=StorageLocation("prefix2/file.mdio"), + ... output_location=StorageLocation("prefix/file.segy"), ... ) If we want to export this as an IEEE big-endian, using a selection mask, we would run: >>> mdio_to_segy( - ... mdio_path_or_buffer="prefix2/file.mdio", - ... output_segy_path="prefix/file.segy", + ... input_location=StorageLocation("prefix2/file.mdio"), + ... output_location=StorageLocation("prefix/file.segy"), ... selection_mask=boolean_mask, ... ) """ backend = "dask" - output_segy_path = Path(output_segy_path) + if not overwrite and output_location.exists(): + err = f"Output location '{output_location.uri}' exists. Set `overwrite=True` if intended." 
+ raise FileExistsError(err) - mdio = MDIOReader( - mdio_path_or_buffer=mdio_path_or_buffer, - access_pattern=access_pattern, - storage_options=storage_options, - ) + output_segy_path = Path(output_location.uri) if new_chunks is None: - new_chunks = segy_export_rechunker(mdio.chunks, mdio.shape, mdio._traces.dtype) + ds_tmp = xr.open_dataset(input_location.uri, engine="zarr", mask_and_scale=False) + amp = ds_tmp["amplitude"] + chunks = amp.encoding.get("chunks") + shape = amp.shape + dtype = amp.dtype + new_chunks = segy_export_rechunker(chunks, shape, dtype) + ds_tmp.close() creation_args = [ - mdio_path_or_buffer, - output_segy_path, - access_pattern, + input_location, + output_location, endian, - storage_options, new_chunks, backend, ] @@ -111,14 +116,16 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913 if distributed is not None: # This is in case we work with big data feature = client.submit(mdio_spec_to_segy, *creation_args) - mdio, segy_factory = feature.result() + ds, segy_factory = feature.result() else: msg = "Distributed client was provided, but `distributed` is not installed" raise ImportError(msg) else: - mdio, segy_factory = mdio_spec_to_segy(*creation_args) + ds, segy_factory = mdio_spec_to_segy(*creation_args) + + amp_da, headers_da, trace_mask_da, _, _ = get_required_segy_fields(ds) - live_mask = mdio.live_mask.compute() + live_mask = trace_mask_da.data.compute() if selection_mask is not None: live_mask = live_mask & selection_mask @@ -138,12 +145,18 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913 dim_slices += (slice(start, stop),) # Lazily pull the data with limits now, and limit mask so its the same shape. - live_mask, headers, samples = mdio[dim_slices] - live_mask = live_mask.rechunk(headers.chunks) + trace_mask_da = trace_mask_da.data + headers = headers_da.data + samples = amp_da.data + + live_mask_da = trace_mask_da[dim_slices] + headers = headers[dim_slices] + samples = samples[dim_slices] + live_mask_da = live_mask_da.rechunk(headers.chunks) if selection_mask is not None: selection_mask = selection_mask[dim_slices] - live_mask = live_mask & selection_mask + live_mask_da = live_mask_da & selection_mask # tmp file root out_dir = output_segy_path.parent @@ -154,7 +167,7 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913 block_records = to_segy( samples=samples, headers=headers, - live_mask=live_mask, + live_mask=live_mask_da, segy_factory=segy_factory, file_root=tmp_dir.name, ) diff --git a/src/mdio/segy/creation.py b/src/mdio/segy/creation.py index b1daf7d5..749bd1af 100644 --- a/src/mdio/segy/creation.py +++ b/src/mdio/segy/creation.py @@ -7,31 +7,70 @@ from pathlib import Path from shutil import copyfileobj from typing import TYPE_CHECKING -from typing import Any import numpy as np +import xarray as xr from segy.factory import SegyFactory from segy.schema import Endianness from segy.schema import SegySpec from tqdm.auto import tqdm -from mdio.api.accessor import MDIOReader from mdio.segy.compat import mdio_segy_spec from mdio.segy.compat import revision_encode if TYPE_CHECKING: from numpy.typing import NDArray + from xarray import Dataset as xr_Dataset + + from mdio.core.storage_location import StorageLocation logger = logging.getLogger(__name__) -def make_segy_factory(mdio: MDIOReader, spec: SegySpec) -> SegyFactory: +def get_required_segy_fields( + ds: xr_Dataset, +) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray, dict, str]: + """Validate that required fields exist in an MDIO dataset. + + Args: + ds: Dataset to validate. 
+
+    Returns:
+        Tuple containing amplitude, headers, trace_mask variables, attributes dict,
+        and the API version string.
+
+    Raises:
+        KeyError: If any of the required fields are missing.
+    """
+    # TODO (BrianMichell): Define and implement field inference  # noqa: TD003
+    missing = [f"attrs['{attr}']" for attr in ("apiVersion", "attributes") if attr not in ds.attrs]
+
+    attributes = ds.attrs.get("attributes", {})
+    missing.extend(f"attrs['attributes']['{key}']" for key in ("textHeader", "binaryHeader") if key not in attributes)
+
+    missing.extend(var for var in ("amplitude", "headers", "trace_mask") if var not in ds)
+
+    if missing:
+        err = ", ".join(missing)
+        msg = f"Missing required field(s): {err}"
+        raise KeyError(msg)
+
+    return (
+        ds["amplitude"],
+        ds["headers"],
+        ds["trace_mask"],
+        attributes,
+        ds.attrs["apiVersion"],
+    )
+
+
+def make_segy_factory(ds: xr_Dataset, spec: SegySpec) -> SegyFactory:
     """Generate SEG-Y factory from MDIO metadata."""
-    grid = mdio.grid
-    sample_dim = grid.select_dim("sample")
-    sample_interval = sample_dim[1] - sample_dim[0]
-    samples_per_trace = len(sample_dim)
+    sample_dim_name = ds["amplitude"].dims[-1]
+    sample_coord = ds[sample_dim_name].values
+    sample_interval = sample_coord[1] - sample_coord[0]
+    samples_per_trace = len(sample_coord)
 
     return SegyFactory(
         spec=spec,
@@ -40,65 +79,61 @@
     )
 
 
-def mdio_spec_to_segy(  # noqa: PLR0913
-    mdio_path_or_buffer: str,
-    output_segy_path: Path,
-    access_pattern: str,
+def mdio_spec_to_segy(
+    input_location: StorageLocation,
+    output_location: StorageLocation,
     output_endian: str,
-    storage_options: dict[str, Any],
     new_chunks: tuple[int, ...],
     backend: str,
-) -> tuple[MDIOReader, SegyFactory]:
+) -> tuple[xr_Dataset, SegyFactory]:
     """Create SEG-Y file without any traces given MDIO specification.
 
     This function opens an MDIO file, gets some relevant information for SEG-Y files, then
     creates a SEG-Y file with the specification it read from the MDIO file.
 
-    It then returns the `MDIOReader` instance, and the parsed floating point format `sample_format`
-    for further use.
+    It then returns the opened xarray dataset and a configured `SegyFactory`
+    for further use.
 
     Function will attempt to read text, and binary headers, and some grid information from the
     MDIO file. If these don't exist, the process will fail.
 
     Args:
-        mdio_path_or_buffer: Store or URL for MDIO file.
-        output_segy_path: Path to the output SEG-Y file.
-        access_pattern: Chunk access pattern, optional. Default is "012". Examples: '012', '01'.
+        input_location: Location of the MDIO file.
+        output_location: Location of the output SEG-Y file.
         output_endian: Endianness of the output file.
-        storage_options: Options for the storage backend. By default, system-wide credentials
-            will be used.
         new_chunks: Set manual chunksize. For development purposes only.
        backend: Backend selection, optional. Default is "zarr". Must be in {'zarr', 'dask'}.
Returns: - Initialized MDIOReader for MDIO file and return SegyFactory + Opened xarray Dataset for MDIO file and configured SegyFactory """ - mdio = MDIOReader( - mdio_path_or_buffer=mdio_path_or_buffer, - access_pattern=access_pattern, - storage_options=storage_options, - return_metadata=True, - new_chunks=new_chunks, - backend=backend, - disk_cache=False, # Making sure disk caching is disabled - ) + ds = xr.open_dataset(input_location.uri, engine="zarr", mask_and_scale=False) + + amp, _, _, attributes, mdio_file_version = get_required_segy_fields(ds) + + if backend == "dask" and new_chunks is not None: + chunk_map = dict(zip(amp.dims, new_chunks, strict=False)) + ds = ds.chunk(chunk_map) - mdio_file_version = mdio.root.attrs["api_version"] spec = mdio_segy_spec(mdio_file_version) spec.endianness = Endianness(output_endian) - factory = make_segy_factory(mdio, spec=spec) + factory = make_segy_factory(ds, spec=spec) + + text_field = attributes["textHeader"] + if isinstance(text_field, list): + text_field = "".join(text_field) - text_str = "\n".join(mdio.text_header) - text_bytes = factory.create_textual_header(text_str) + text_bytes = factory.create_textual_header(text_field) - binary_header = revision_encode(mdio.binary_header, mdio_file_version) + binary_header = revision_encode(attributes["binaryHeader"], mdio_file_version) bin_hdr_bytes = factory.create_binary_header(binary_header) - with output_segy_path.open(mode="wb") as fp: + output_path = Path(output_location.uri) + with output_path.open(mode="wb") as fp: fp.write(text_bytes) fp.write(bin_hdr_bytes) - return mdio, factory + return ds, factory @dataclass(slots=True) diff --git a/tests/integration/test_segy_import_export.py b/tests/integration/test_segy_import_export.py index e80028df..9f45ae20 100644 --- a/tests/integration/test_segy_import_export.py +++ b/tests/integration/test_segy_import_export.py @@ -57,15 +57,22 @@ def test_import_4d_segy( # noqa: PLR0913 """Test importing a SEG-Y file to MDIO.""" segy_path = segy_mock_4d_shots[chan_header_type] - segy_to_mdio( - segy_path=segy_path, - mdio_path_or_buffer=zarr_tmp.__str__(), + _ = grid_overrides + + segy_spec = get_segy_standard(1.0) + segy_spec = customize_segy_specs( + segy_spec=segy_spec, index_bytes=index_bytes, index_names=index_names, index_types=index_types, - chunksize=(8, 2, 10), + ) + + segy_to_mdio( + segy_spec=segy_spec, + mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"), + input_location=StorageLocation(str(segy_path)), + output_location=StorageLocation(str(zarr_tmp)), overwrite=True, - grid_overrides=grid_overrides, ) # Expected values @@ -107,15 +114,20 @@ def test_import_4d_segy( # noqa: PLR0913 """Test importing a SEG-Y file to MDIO.""" segy_path = segy_mock_4d_shots[chan_header_type] - segy_to_mdio( - segy_path=segy_path, - mdio_path_or_buffer=zarr_tmp.__str__(), + segy_spec = get_segy_standard(1.0) + segy_spec = customize_segy_specs( + segy_spec=segy_spec, index_bytes=index_bytes, index_names=index_names, index_types=index_types, - chunksize=(8, 2, 128, 1024), + ) + + segy_to_mdio( + segy_spec=segy_spec, + mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"), + input_location=StorageLocation(str(segy_path)), + output_location=StorageLocation(str(zarr_tmp)), overwrite=True, - grid_overrides=grid_overrides, ) # Expected values @@ -165,14 +177,20 @@ def test_import_4d_segy( # noqa: PLR0913 segy_path = segy_mock_4d_shots[chan_header_type] os.environ["MDIO__GRID__SPARSITY_RATIO_LIMIT"] = "1.1" + segy_spec = get_segy_standard(1.0) + 
segy_spec = customize_segy_specs( + segy_spec=segy_spec, + index_bytes=index_bytes, + index_names=index_names, + index_types=index_types, + ) + with pytest.raises(GridTraceSparsityError) as execinfo: segy_to_mdio( - segy_path=segy_path, - mdio_path_or_buffer=zarr_tmp.__str__(), - index_bytes=index_bytes, - index_names=index_names, - index_types=index_types, - chunksize=(8, 2, 128, 1024), + segy_spec=segy_spec, + mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"), + input_location=StorageLocation(str(segy_path)), + output_location=StorageLocation(str(zarr_tmp)), overwrite=True, ) @@ -201,15 +219,20 @@ def test_import_6d_segy( # noqa: PLR0913 """Test importing a SEG-Y file to MDIO.""" segy_path = segy_mock_4d_shots[chan_header_type] - segy_to_mdio( - segy_path=segy_path, - mdio_path_or_buffer=zarr_tmp.__str__(), + segy_spec = get_segy_standard(1.0) + segy_spec = customize_segy_specs( + segy_spec=segy_spec, index_bytes=index_bytes, index_names=index_names, index_types=index_types, - chunksize=(1, 1, 8, 1, 12, 36), + ) + + segy_to_mdio( + segy_spec=segy_spec, + mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"), + input_location=StorageLocation(str(segy_path)), + output_location=StorageLocation(str(zarr_tmp)), overwrite=True, - grid_overrides=grid_overrides, ) # Expected values @@ -410,8 +433,8 @@ class TestExport: def test_3d_export(self, zarr_tmp: Path, segy_export_tmp: Path) -> None: """Test 3D export to IBM and IEEE.""" mdio_to_segy( - mdio_path_or_buffer=zarr_tmp.__str__(), - output_segy_path=segy_export_tmp.__str__(), + input_location=StorageLocation(zarr_tmp.__str__()), + output_location=StorageLocation(segy_export_tmp.__str__()), ) def test_size_equal(self, segy_input: Path, segy_export_tmp: Path) -> None: diff --git a/tests/integration/test_segy_import_export_masked.py b/tests/integration/test_segy_import_export_masked.py index 5c310865..1e7fcd85 100644 --- a/tests/integration/test_segy_import_export_masked.py +++ b/tests/integration/test_segy_import_export_masked.py @@ -368,9 +368,8 @@ def test_export(self, test_conf: MaskedExportConfig, export_masked_path: Path) - new_chunks = segy_export_rechunker(chunks, shape, dtype="float32", limit="0.3M") mdio_to_segy( - mdio_path.__str__(), - segy_rt_path.__str__(), - access_pattern=access_pattern, + input_location=StorageLocation(mdio_path.__str__()), + output_location=StorageLocation(segy_rt_path.__str__()), new_chunks=new_chunks, ) @@ -395,9 +394,8 @@ def test_export_masked(self, test_conf: MaskedExportConfig, export_masked_path: selection_mask = generate_selection_mask(selection_conf, grid_conf) mdio_to_segy( - mdio_path.__str__(), - segy_rt_path.__str__(), - access_pattern=access_pattern, + input_location=StorageLocation(mdio_path.__str__()), + output_location=StorageLocation(segy_rt_path.__str__()), new_chunks=export_chunks, selection_mask=selection_mask, ) diff --git a/tests/unit/test_segy_field_validation.py b/tests/unit/test_segy_field_validation.py new file mode 100644 index 00000000..47b32f8a --- /dev/null +++ b/tests/unit/test_segy_field_validation.py @@ -0,0 +1,50 @@ +"""Tests for SEG-Y field validation helper.""" + +import numpy as np +import pytest +import xarray as xr + +from mdio.segy.creation import get_required_segy_fields + + +def _make_dataset() -> xr.Dataset: + data = np.zeros((1, 1, 1), dtype=np.float32) + headers = np.zeros((1, 1, 1), dtype=np.int32) + mask = np.ones((1, 1, 1), dtype=bool) + return xr.Dataset( + { + "amplitude": (("x", "y", "sample"), data), + "headers": (("x", "y", 
"header"), headers), + "trace_mask": (("x", "y", "mask"), mask), + }, + attrs={ + "apiVersion": "1.0.0", + "attributes": {"textHeader": "", "binaryHeader": {}}, + }, + ) + + +def test_get_required_segy_fields_returns_all() -> None: + """Return tuple when all required fields exist.""" + ds = _make_dataset() + amp, hdr, mask, attrs, version = get_required_segy_fields(ds) + assert amp.name == "amplitude" + assert hdr.name == "headers" + assert mask.name == "trace_mask" + assert attrs == {"textHeader": "", "binaryHeader": {}} + assert version == "1.0.0" + + +def test_get_required_segy_fields_missing_variable() -> None: + """Raise when a required data variable is missing.""" + ds = _make_dataset().drop_vars("amplitude") + with pytest.raises(KeyError): + get_required_segy_fields(ds) + + +def test_get_required_segy_fields_missing_attribute() -> None: + """Raise when a required attribute is missing.""" + ds = _make_dataset() + del ds.attrs["attributes"]["textHeader"] + with pytest.raises(KeyError): + get_required_segy_fields(ds)