Commit 515ca6c

Working tests with fake server

1 parent e10c50c commit 515ca6c

File tree

4 files changed: +235 -83 lines changed

src/mdio/api/io.py

Lines changed: 126 additions & 33 deletions
@@ -23,13 +23,46 @@
     from xarray.core.types import T_Chunks
     from xarray.core.types import ZarrWriteModes
 
-
 def _normalize_path(path: UPath | Path | str) -> UPath:
-    return UPath(path)
+    """Normalize a path to a UPath.
+
+    For gs:// paths, the fake GCS server configuration is handled via storage_options
+    in _normalize_storage_options().
+    """
+    from upath import UPath
+
+    return UPath(path)
 
 
 def _normalize_storage_options(path: UPath) -> dict[str, Any] | None:
-    return None if len(path.storage_options) == 0 else path.storage_options
+    """Normalize and patch storage options for UPath paths.
+
+    - Extracts any existing options from the UPath.
+    - Automatically redirects gs:// URLs to a local fake-GCS endpoint
+      when testing (localhost:4443).
+    """
+    import gcsfs
+
+    # Start with any existing options from UPath
+    storage_options = dict(path.storage_options) if len(path.storage_options) else {}
+
+    # Redirect gs:// to local fake-GCS server for testing
+    if str(path).startswith("gs://"):
+        fs = gcsfs.GCSFileSystem(
+            endpoint_url="http://localhost:4443",
+            token="anon",
+        )
+        base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
+        print(f"[mdio.utils] Redirecting GCS path to local fake server: {base_url}")
+        storage_options["fs"] = fs
+
+    return storage_options or None
+
+
+# def _normalize_path(path: UPath | Path | str) -> UPath:
+#     return UPath(path)
+
+
+# def _normalize_storage_options(path: UPath) -> dict[str, Any] | None:
+#     return None if len(path.storage_options) == 0 else path.storage_options
 
 
 def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dataset:

@@ -49,6 +82,8 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dataset:
     Returns:
         An Xarray dataset opened from the input path.
     """
+    import zarr
+
     input_path = _normalize_path(input_path)
     storage_options = _normalize_storage_options(input_path)
     zarr_format = zarr.config.get("default_zarr_format")

@@ -61,43 +96,101 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dataset:
         consolidated=zarr_format == ZarrFormat.V2,  # on for v2, off for v3
     )
 
-
-def to_mdio(  # noqa: PLR0913
+def to_mdio(
     dataset: Dataset,
     output_path: UPath | Path | str,
     mode: ZarrWriteModes | None = None,
     *,
     compute: bool = True,
-    region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
-) -> None:
-    """Write dataset contents to an MDIO output_path.
+    region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,):
+    """Write dataset contents to an MDIO output_path."""
+    import gcsfs
+    import zarr
 
-    Args:
-        dataset: The dataset to write.
-        output_path: The universal path of the output MDIO file.
-        mode: Persistence mode: "w" means create (overwrite if exists)
-            "w-" means create (fail if exists)
-            "a" means override all existing variables including dimension coordinates (create if does not exist)
-            "a-" means only append those variables that have ``append_dim``.
-            "r+" means modify existing array *values* only (raise an error if any metadata or shapes would change).
-            The default mode is "r+" if ``region`` is set and ``w-`` otherwise.
-        compute: If True write array data immediately; otherwise return a ``dask.delayed.Delayed`` object that
-            can be computed to write array data later. Metadata is always updated eagerly.
-        region: Optional mapping from dimension names to either a) ``"auto"``, or b) integer slices, indicating
-            the region of existing MDIO array(s) in which to write this dataset's data.
-    """
     output_path = _normalize_path(output_path)
-    storage_options = _normalize_storage_options(output_path)
     zarr_format = zarr.config.get("default_zarr_format")
 
-    with zarr_warnings_suppress_unstable_structs_v3():
-        xr_to_zarr(
-            dataset,
-            store=output_path.as_posix(),  # xarray doesn't like URI when file:// is protocol
-            mode=mode,
-            compute=compute,
-            consolidated=zarr_format == ZarrFormat.V2,  # on for v2, off for v3
-            region=region,
-            storage_options=storage_options,
-            write_empty_chunks=False,
+    # For GCS paths, create FSMap for fake GCS server
+    if str(output_path).startswith("gs://"):
+        fs = gcsfs.GCSFileSystem(
+            endpoint_url="http://localhost:4443",
+            token="anon",
         )
+        base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
+        print(f"[mdio.utils] Using fake GCS mapper via {base_url}")
+        store = fs.get_mapper(output_path.as_posix().replace("gs://", ""))
+        storage_options = None  # Must be None when passing a mapper
+    else:
+        store = output_path.as_posix()
+        storage_options = _normalize_storage_options(output_path)
+
+    print(f"[mdio.utils] Writing to store: {store}")
+    print(f"[mdio.utils] Storage options: {storage_options}")
+
+    kwargs = dict(
+        dataset=dataset,
+        store=store,
+        mode=mode,
+        compute=compute,
+        consolidated=zarr_format == ZarrFormat.V2,
+        region=region,
+        write_empty_chunks=False,
+    )
+    if storage_options is not None and not isinstance(store, dict):
+        kwargs["storage_options"] = storage_options
+
+    with zarr_warnings_suppress_unstable_structs_v3():
+        xr_to_zarr(**kwargs)
+
+
+# def to_mdio(  # noqa: PLR0913
+#     dataset: Dataset,
+#     output_path: UPath | Path | str,
+#     mode: ZarrWriteModes | None = None,
+#     *,
+#     compute: bool = True,
+#     region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+# ) -> None:
+#     """Write dataset contents to an MDIO output_path."""
+#     import gcsfs, zarr
+
+#     output_path = _normalize_path(output_path)
+
+#     if output_path.as_posix().startswith("gs://"):
+#         fs = gcsfs.GCSFileSystem(
+#             endpoint_url="http://localhost:4443",
+#             token="anon",
+#         )
+
+#         base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
+#         print(f"Using custom fake GCS filesystem with endpoint {base_url}")
+
+#         # Build a mapper so all I/O uses the fake GCS server
+#         mapper = fs.get_mapper(output_path.as_posix().replace("gs://", ""))
+#         store = mapper
+#         storage_options = None  # must be None when passing a mapper
+#     else:
+#         store = output_path.as_posix()
+#         storage_options = _normalize_storage_options(output_path) or {}
+
+#     print(f"Writing to store: {store}")
+#     zarr_format = zarr.config.get("default_zarr_format")
+
+#     kwargs = dict(
+#         dataset=dataset,
+#         store=store,
+#         mode=mode,
+#         compute=compute,
+#         consolidated=zarr_format == ZarrFormat.V2,
+#         region=region,
+#         write_empty_chunks=False,
+#     )
+#     if storage_options is not None:
+#         kwargs["storage_options"] = storage_options
+
+#     with zarr_warnings_suppress_unstable_structs_v3():
+#         xr_to_zarr(**kwargs)
+
+
+
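Note: a minimal sketch of exercising this fake-GCS redirect end to end, outside the diff. It assumes a fake GCS server (e.g. the fsouza/fake-gcs-server Docker image) is already listening on localhost:4443 with a pre-created bucket; the bucket name "test-bucket" and the toy dataset are illustrative, not part of this commit.

    import gcsfs
    import numpy as np
    import xarray as xr

    # Anonymous gcsfs client pointed at the local emulator instead of real GCS.
    fs = gcsfs.GCSFileSystem(endpoint_url="http://localhost:4443", token="anon")

    # A tiny dataset standing in for an MDIO dataset.
    ds = xr.Dataset({"amplitude": (("inline", "crossline"), np.zeros((4, 4), dtype="float32"))})

    # Write through an fsspec mapper, mirroring the mapper branch in to_mdio().
    store = fs.get_mapper("test-bucket/demo.zarr")
    ds.to_zarr(store, mode="w")

    # Read back through the same mapper to confirm the round trip.
    print(xr.open_zarr(store).amplitude.shape)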

src/mdio/converters/segy.py

Lines changed: 6 additions & 3 deletions
@@ -531,9 +531,11 @@ def segy_to_mdio(  # noqa PLR0913
     input_path = _normalize_path(input_path)
     output_path = _normalize_path(output_path)
 
+    print("Checking if output path exists...")
     if not overwrite and output_path.exists():
         err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
         raise FileExistsError(err)
+    print("Output path check was fine")
 
     segy_settings = SegyFileSettings(storage_options=input_path.storage_options)
     segy_file_kwargs: SegyFileArguments = {

@@ -589,15 +591,16 @@ def segy_to_mdio(  # noqa PLR0913
     # blocked_io.to_zarr() -> _workers.trace_worker()
 
     # This will create the Zarr store with the correct structure but with empty arrays
+    print("Creating Zarr store...")
     to_mdio(xr_dataset, output_path=output_path, mode="w", compute=False)
-
+    print("Zarr store created")
     # This will write the non-dimension coordinates and trace mask
     meta_ds = xr_dataset[drop_vars_delayed + ["trace_mask"]]
     to_mdio(meta_ds, output_path=output_path, mode="r+", compute=True)
-
+    print("Non-dimension coordinates and trace mask written")
     # Now we can drop them to simplify chunked write of the data variable
     xr_dataset = xr_dataset.drop_vars(drop_vars_delayed)
-
+    print("Dropped variables")
     # Write the headers and traces in chunks using grid_map to indicate dead traces
     default_variable_name = mdio_template.default_variable_name
     # This is an memory-expensive and time-consuming read-write operation
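Note: the three-phase write this function performs (empty skeleton via compute=False, eager metadata via mode="r+", then chunked trace data) can be reproduced with plain xarray. A minimal sketch, assuming dask is installed and using a local path with made-up variable names rather than the real MDIO schema:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {
            "seismic": (("inline", "crossline"), np.zeros((8, 8), dtype="float32")),
            "trace_mask": (("inline", "crossline"), np.ones((8, 8), dtype=bool)),
        }
    ).chunk({"inline": 4, "crossline": 4})  # dask-backed, so compute=False stays lazy

    # Phase 1: create the store structure with empty arrays (metadata only).
    ds.to_zarr("demo.zarr", mode="w", compute=False)

    # Phase 2: eagerly fill the lightweight variables in the existing store.
    ds[["trace_mask"]].to_zarr("demo.zarr", mode="r+")

    # Phase 3: write the heavy data variable (real ingestion does this region by region in workers).
    ds[["seismic"]].to_zarr("demo.zarr", mode="r+")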

src/mdio/segy/_workers.py

Lines changed: 46 additions & 40 deletions
@@ -8,25 +8,47 @@
 import numpy as np
 from segy.arrays import HeaderArray
 
-from mdio.api.io import _normalize_storage_options
 from mdio.core.config import MDIOSettings
 from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
 from mdio.segy.file import SegyFileArguments
 from mdio.segy.file import SegyFileWrapper
 
 if TYPE_CHECKING:
-    from upath import UPath
     from zarr import Array as zarr_Array
+    from zarr import Group as zarr_Group
+    from segy import SegyFile
 
-from zarr import open_group as zarr_open_group
 from zarr.core.config import config as zarr_config
-
 from mdio.builder.schemas.v1.stats import CenteredBinHistogram
 from mdio.builder.schemas.v1.stats import SummaryStatistics
 from mdio.constants import fill_value_map
 
 logger = logging.getLogger(__name__)
 
+# Global variable to store opened segy file per worker process
+_worker_segy_file = None
+
+
+def _init_worker(segy_file_kwargs: SegyFileArguments) -> None:
+    """Initialize worker process with persistent segy file handle.
+
+    This function is called once per worker process to open the segy file,
+    which is then reused across all tasks in that worker.
+
+    Args:
+        segy_file_kwargs: Arguments to open SegyFile instance.
+    """
+    global _worker_segy_file
+
+    from segy import SegyFile
+
+    # Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__MAX_WORKERS` environment variable.
+    # The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage.
+    zarr_config.set({"threading.max_workers": 1})
+
+    # Open the SEG-Y file once per worker
+    _worker_segy_file = SegyFile(**segy_file_kwargs)
 
 
 def header_scan_worker(
     segy_file_kwargs: SegyFileArguments,

@@ -71,55 +93,42 @@ def header_scan_worker(
 
 
 def trace_worker(  # noqa: PLR0913
-    segy_file_kwargs: SegyFileArguments,
-    output_path: UPath,
-    data_variable_name: str,
+    data_array: zarr_Array,
+    header_array: zarr_Array | None,
+    raw_header_array: zarr_Array | None,
     region: dict[str, slice],
-    grid_map: zarr_Array,
+    grid_map_data: np.ndarray,
 ) -> SummaryStatistics | None:
     """Writes a subset of traces from a region of the dataset of Zarr file.
+
+    Uses pre-opened segy file from _init_worker and receives zarr arrays directly.
 
     Args:
-        segy_file_kwargs: Arguments to open SegyFile instance.
-        output_path: Universal Path for the output Zarr dataset
-            (e.g. local file path or cloud storage URI) the location
-            also includes storage options for cloud storage.
-        data_variable_name: Name of the data variable to write.
+        data_array: Zarr array for writing trace data.
+        header_array: Zarr array for writing trace headers (or None if not needed).
+        raw_header_array: Zarr array for writing raw headers (or None if not needed).
         region: Region of the dataset to write to.
-        grid_map: Zarr array mapping live traces to their positions in the dataset.
+        grid_map_data: Numpy array mapping live traces to their positions in the dataset.
 
     Returns:
         SummaryStatistics object containing statistics about the written traces.
     """
+    global _worker_segy_file
+
+    # Use the pre-opened segy file from worker initialization
+    segy_file = _worker_segy_file
+
     region_slices = tuple(region.values())
-    local_grid_map = grid_map[region_slices[:-1]]  # minus last (vertical) axis
+    local_grid_map = grid_map_data[region_slices[:-1]]  # minus last (vertical) axis
 
     # The dtype.max is the sentinel value for the grid map.
     # Normally, this is uint32, but some grids need to be promoted to uint64.
     not_null = local_grid_map != fill_value_map.get(local_grid_map.dtype.name)
     if not not_null.any():
         return None
 
-    # Open the SEG-Y file in this process since the open file handles cannot be shared across processes.
-    segy_file = SegyFileWrapper(**segy_file_kwargs)
-
-    # Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__MAX_WORKERS` environment variable.
-    # The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage.
-    zarr_config.set({"threading.max_workers": 1})
-
     live_trace_indexes = local_grid_map[not_null].tolist()
 
-    # Open the zarr group to write directly
-    storage_options = _normalize_storage_options(output_path)
-    zarr_group = zarr_open_group(output_path.as_posix(), mode="r+", storage_options=storage_options)
-
-    header_key = "headers"
-    raw_header_key = "raw_headers"
-
-    # Check which variables exist in the zarr store
-    available_arrays = list(zarr_group.array_keys())
-
-    # traces = segy_file.trace[live_trace_indexes]
     # Raw headers are not intended to remain as a feature of the SEGY ingestion.
     # For that reason, we have wrapped the accessors to provide an interface that can be removed
     # and not require additional changes to the below code.

@@ -132,24 +141,21 @@
     full_shape = tuple(s.stop - s.start for s in region_slices)
     header_shape = tuple(s.stop - s.start for s in header_region_slices)
 
-    # Write raw headers if they exist
+    # Write raw headers if array was provided
     # Headers only have spatial dimensions (no sample dimension)
-    if raw_header_key in available_arrays:
-        raw_header_array = zarr_group[raw_header_key]
+    if raw_header_array is not None:
         tmp_raw_headers = np.full(header_shape, raw_header_array.fill_value)
         tmp_raw_headers[not_null] = traces.raw_header
         raw_header_array[header_region_slices] = tmp_raw_headers
 
-    # Write headers if they exist
+    # Write headers if array was provided
     # Headers only have spatial dimensions (no sample dimension)
-    if header_key in available_arrays:
-        header_array = zarr_group[header_key]
+    if header_array is not None:
         tmp_headers = np.full(header_shape, header_array.fill_value)
         tmp_headers[not_null] = traces.header
         header_array[header_region_slices] = tmp_headers
 
     # Write the data variable
-    data_array = zarr_group[data_variable_name]
     tmp_samples = np.full(full_shape, data_array.fill_value)
     tmp_samples[not_null] = traces.sample
     data_array[region_slices] = tmp_samples
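Note: the _init_worker/_worker_segy_file pair above is the standard process-pool initializer pattern. A minimal self-contained sketch of the same idea, with a generic resource standing in for the SEG-Y handle; the pool wiring shown here is illustrative only, since the actual executor setup lives elsewhere in the codebase (blocked_io.to_zarr):

    from concurrent.futures import ProcessPoolExecutor

    # Module-level slot, filled once per worker process by the initializer.
    _worker_resource = None

    def _init_worker(config: dict) -> None:
        global _worker_resource
        # Expensive setup (e.g. opening a file) happens once per process, not once per task.
        _worker_resource = {"name": config["name"]}

    def task(i: int) -> str:
        # Every task scheduled on this process reuses the initializer's resource.
        return f"task {i} used {_worker_resource['name']}"

    if __name__ == "__main__":
        with ProcessPoolExecutor(max_workers=2, initializer=_init_worker, initargs=({"name": "segy"},)) as pool:
            print(list(pool.map(task, range(4))))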
