
Commit a23fb84

Fix to_zarr function signature to not be changed, don't reopen the zarr store
1 parent 30a6472 commit a23fb84

File tree: 2 files changed (+13 −75 lines)

src/mdio/converters/segy.py
src/mdio/segy/blocked_io.py


src/mdio/converters/segy.py

Lines changed: 3 additions & 9 deletions
@@ -354,7 +354,6 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915
         ... grid_overrides={"HasDuplicates": True},
         ... )
     """
-    print("Entering segy_to_mdio")
     index_names = index_names or [f"dim_{i}" for i in range(len(index_bytes))]
     index_types = index_types or ["int32"] * len(index_bytes)
 
@@ -379,7 +378,7 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915
     num_traces = segy.num_traces
 
     # Index the dataset using a spec that interprets the user provided index headers.
-    index_fields: list[HeaderField] = []
+    index_fields = []
     for name, byte, format_ in zip(index_names, index_bytes, index_types, strict=True):
         index_fields.append(HeaderField(name=name, byte=byte, format=format_))
     mdio_spec_grid = mdio_spec.customize(trace_header_fields=index_fields)
 
@@ -489,16 +488,11 @@ def segy_to_mdio(  # noqa: PLR0913, PLR0915
     write_attribute(name="binary_header", zarr_group=meta_group, attribute=binary_header.to_dict())
 
     # Write traces
-    zarr_root = mdio_path_or_buffer  # the same path you passed earlier to create_empty
-    data_var = f"data/chunked_{suffix}"
-    header_var = f"metadata/chunked_{suffix}_trace_headers"
-
     stats = blocked_io.to_zarr(
         segy_file=segy,
         grid=grid,
-        zarr_root_path=zarr_root,
-        data_var_path=data_var,
-        header_var_path=header_var,
+        data_array=data_array,
+        header_array=header_array,
     )
 
     # Write actual stats
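
For context, the call site after this commit reads roughly as below. This is a sketch only: where data_array and header_array are created is outside this hunk, so the zarr_root group handle and the lookup paths (taken from the deleted helper variables above) are assumptions, not code from the commit.

    # Hypothetical: fetch handles from the group segy_to_mdio already has open,
    # using the same paths the deleted helper variables pointed at.
    data_array = zarr_root[f"data/chunked_{suffix}"]
    header_array = zarr_root[f"metadata/chunked_{suffix}_trace_headers"]

    stats = blocked_io.to_zarr(
        segy_file=segy,
        grid=grid,
        data_array=data_array,
        header_array=header_array,
    )

Passing the array handles directly keeps to_zarr's signature stable and avoids reopening a store the function already has open.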

src/mdio/segy/blocked_io.py

Lines changed: 10 additions & 66 deletions
@@ -1,32 +1,5 @@
 """Functions for doing blocked I/O from SEG-Y."""
 
-# from __future__ import annotations
-
-# import multiprocessing as mp
-# import os
-# from concurrent.futures import ProcessPoolExecutor
-# from itertools import repeat
-# from typing import TYPE_CHECKING, Any
-
-# import numpy as np
-# from psutil import cpu_count
-# from tqdm.auto import tqdm
-# import zarr
-
-# # from mdio.core.indexing import ChunkIterator
-# # from mdio.segy._workers import trace_worker
-
-# from mdio.core.indexing import ChunkIterator
-# from mdio.segy._workers import trace_worker
-# from mdio.segy.creation import SegyPartRecord
-# from mdio.segy.creation import concat_files
-# from mdio.segy.creation import serialize_to_segy_stack
-# from mdio.segy.utilities import find_trailing_ones_index
-
-# if TYPE_CHECKING:
-#     from segy import SegyFile
-#     from mdio.core import Grid
-
 from __future__ import annotations
 
 import multiprocessing as mp
 
@@ -35,6 +8,7 @@
 from itertools import repeat
 from pathlib import Path
 from typing import TYPE_CHECKING
+from typing import Any
 
 import numpy as np
 from dask.array import Array
 
@@ -55,61 +29,33 @@
 from numpy.typing import NDArray
 from segy import SegyFactory
 from segy import SegyFile
+from zarr import Array as ZarrArray
 
 from mdio.core import Grid
 
 default_cpus = cpu_count(logical=True)
 
 
-def _worker_reopen(
-    zarr_root_path: str,
-    data_var_path: str,
-    header_var_path: str,
-    segy_file: SegyFile,
-    grid: Grid,
-    chunk_indices: tuple[slice, ...],
-) -> tuple[Any, ...] | None:
-    """
-    Worker function that reopens the Zarr store in this process,
-    obtains fresh array handles, and calls the real trace_worker.
-    """
-    root = zarr.open_group(zarr_root_path, mode="r+")
-    data_arr = root[data_var_path]
-    header_arr = root[header_var_path]
-    result = trace_worker(segy_file, data_arr, header_arr, grid, chunk_indices)
-    root.store.close()
-    return result
-
-
 def to_zarr(
     segy_file: SegyFile,
     grid: Grid,
-    zarr_root_path: str,
-    data_var_path: str,
-    header_var_path: str,
+    data_array: ZarrArray,
+    header_array: ZarrArray,
 ) -> dict[str, Any]:
     """Blocked I/O from SEG-Y to chunked `zarr.core.Array`.
 
-    Each worker reopens the Zarr store independently to avoid lock contention when writing.
-
     Args:
         segy_file: SEG-Y file instance.
         grid: mdio.Grid instance.
-        zarr_root_path: Filesystem path (or URI) to the root of the MDIO Zarr store.
-        data_var_path: Path within the Zarr group for the data array (e.g., "data/chunked_012").
-        header_var_path: Path within the Zarr group for the header array
-            (e.g., "metadata/chunked_012_trace_headers").
+        data_array: Zarr array for storing trace data.
+        header_array: Zarr array for storing trace headers.
 
     Returns:
         Global statistics for the SEG-Y as a dictionary.
     """
-    # Open Zarr store only in the main process to retrieve shape/metadata
-    root = zarr.open_group(zarr_root_path, mode="r")
-    data_array_meta = root[data_var_path]  # only for shape info
     # Initialize chunk iterator (returns next chunk slice indices each iteration)
-    chunker = ChunkIterator(data_array_meta, chunk_samples=False)
+    chunker = ChunkIterator(data_array, chunk_samples=False)
     num_chunks = len(chunker)
-    root.store.close()  # close immediately
 
     # Determine number of workers
     num_cpus_env = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
 
@@ -125,11 +71,10 @@ def to_zarr(
     # Launch multiprocessing pool
     with ProcessPoolExecutor(max_workers=num_workers, mp_context=context) as executor:
         lazy_work = executor.map(
-            _worker_reopen,
-            repeat(zarr_root_path),
-            repeat(data_var_path),
-            repeat(header_var_path),
+            trace_worker,
             repeat(segy_file),
+            repeat(data_array),
+            repeat(header_array),
             repeat(grid),
             chunker,
             chunksize=pool_chunksize,
 
@@ -166,7 +111,6 @@ def to_zarr(
     return {"mean": glob_mean, "std": glob_std, "rms": glob_rms, "min": glob_min, "max": glob_max}
 
 
-
 def segy_record_concat(
     block_records: NDArray,
     file_root: str,
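
Why handing zarr.Array objects straight to executor.map works (and why the deleted _worker_reopen was unnecessary): zarr arrays are picklable, so each worker process deserializes its own independent handle, and chunk-aligned writes from different processes touch disjoint chunks. Below is a standalone sketch of that pattern, not project code; it assumes zarr-python v2 semantics and a local directory store.

    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor
    from itertools import repeat

    import numpy as np
    import zarr


    def fill_row(array: zarr.Array, row: int) -> int:
        # `array` arrives pickled; zarr rebuilds the handle inside this process.
        array[row, :] = np.full(array.shape[1], row, dtype=array.dtype)
        return row


    if __name__ == "__main__":
        # One chunk per row, so concurrent workers never write to the same chunk.
        arr = zarr.open("demo.zarr", mode="w", shape=(4, 8), chunks=(1, 8), dtype="int32")
        with ProcessPoolExecutor(max_workers=2, mp_context=mp.get_context("spawn")) as pool:
            list(pool.map(fill_row, repeat(arr), range(4)))
        print(arr[:])  # each row holds its index, written by a worker

The map call mirrors the shape of the to_zarr call above: the worker callable first, one repeat(...) per fixed argument, then the iterable of per-chunk work items.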
