
Commit 20cca4b

Cleanup
1 parent 55a28fc · commit 20cca4b

File tree

4 files changed: +64 -139 lines changed


src/mdio/api/io.py

Lines changed: 53 additions & 114 deletions
@@ -13,6 +13,7 @@
 from xarray.backends.writers import to_zarr as xr_to_zarr
 
 from mdio.constants import ZarrFormat
+from mdio.core.config import MDIOSettings
 from mdio.core.zarr_io import zarr_warnings_suppress_unstable_structs_v3
 
 if TYPE_CHECKING:
@@ -23,46 +24,39 @@
 from xarray.core.types import T_Chunks
 from xarray.core.types import ZarrWriteModes
 
-def _normalize_path(path: UPath | Path | str) -> UPath:
-    """Normalize a path to a UPath.
-
-    For gs:// paths, the fake GCS server configuration is handled via storage_options
-    in _normalize_storage_options().
-    """
-    from upath import UPath
 
+def _normalize_path(path: UPath | Path | str) -> UPath:
+    """Normalize a path to a UPath."""
     return UPath(path)
 
+
 def _normalize_storage_options(path: UPath) -> dict[str, Any] | None:
-    """Normalize and patch storage options for UPath paths.
+    """Normalize storage options from UPath."""
+    return None if len(path.storage_options) == 0 else path.storage_options
+
+
+def _get_gcs_store(path: UPath) -> tuple[Any, dict[str, Any] | None]:
+    """Get store and storage options, using local fake GCS server if enabled.
+
+    Args:
+        path: UPath pointing to storage location.
 
-    - Extracts any existing options from the UPath.
-    - Automatically redirects gs:// URLs to a local fake-GCS endpoint
-      when testing (localhost:4443).
+    Returns:
+        Tuple of (store, storage_options) where store is either a mapper or path string.
     """
+    settings = MDIOSettings()
 
-    # Start with any existing options from UPath
-    storage_options = dict(path.storage_options) if len(path.storage_options) else {}
+    if settings.local_gcs_server and str(path).startswith("gs://"):
+        import gcsfs  # noqa: PLC0415
 
-    # Redirect gs:// to local fake-GCS server for testing
-    if str(path).startswith("gs://"):
-        import gcsfs
         fs = gcsfs.GCSFileSystem(
             endpoint_url="http://localhost:4443",
-            token="anon",
+            token="anon",  # noqa: S106
         )
-        base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
-        print(f"[mdio.utils] Redirecting GCS path to local fake server: {base_url}")
-        storage_options["fs"] = fs
-
-    return storage_options or None
-
-# def _normalize_path(path: UPath | Path | str) -> UPath:
-#     return UPath(path)
+        store = fs.get_mapper(path.as_posix().replace("gs://", ""))
+        return store, None
 
-
-# def _normalize_storage_options(path: UPath) -> dict[str, Any] | None:
-#     return None if len(path.storage_options) == 0 else path.storage_options
+    return path.as_posix(), _normalize_storage_options(path)
 
 
 def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dataset:
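A quick illustration of the new helper's contract (a minimal sketch against this commit; the local path is hypothetical): non-gs:// paths pass straight through, and gs:// paths are only redirected when the local_gcs_server setting is enabled.

from upath import UPath

from mdio.api.io import _get_gcs_store

# Plain paths are returned as-is, with storage options taken from the UPath (None if empty).
store, storage_options = _get_gcs_store(UPath("/tmp/example.mdio"))
print(store, storage_options)  # "/tmp/example.mdio" None

# A gs:// path becomes a gcsfs mapper pointed at http://localhost:4443 only when
# MDIOSettings().local_gcs_server is True; otherwise it takes the same branch as above.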
@@ -82,8 +76,6 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dataset:
     Returns:
         An Xarray dataset opened from the input path.
     """
-    import zarr
-
     input_path = _normalize_path(input_path)
     storage_options = _normalize_storage_options(input_path)
     zarr_format = zarr.config.get("default_zarr_format")
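For the read side, a minimal usage sketch of open_mdio (the path and dimension name are hypothetical; chunks is typed as xarray's T_Chunks, so a dict of dimension sizes is presumably forwarded to request dask-backed variables):

from mdio.api.io import open_mdio

# Open an existing MDIO store lazily; consolidated metadata is used for Zarr v2 only.
ds = open_mdio("output/example.mdio", chunks={"inline": 128})
print(ds)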
@@ -96,101 +88,48 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dataset:
         consolidated=zarr_format == ZarrFormat.V2,  # on for v2, off for v3
     )
 
-def to_mdio(
+
+def to_mdio(  # noqa: PLR0913
     dataset: Dataset,
     output_path: UPath | Path | str,
     mode: ZarrWriteModes | None = None,
     *,
     compute: bool = True,
-    region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,):
-    """Write dataset contents to an MDIO output_path."""
-    import zarr
+    region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+) -> None:
+    """Write dataset contents to an MDIO output_path.
 
+    Args:
+        dataset: The dataset to write.
+        output_path: The universal path of the output MDIO file.
+        mode: Persistence mode: "w" means create (overwrite if exists)
+            "w-" means create (fail if exists)
+            "a" means override all existing variables including dimension coordinates (create if does not exist)
+            "a-" means only append those variables that have ``append_dim``.
+            "r+" means modify existing array *values* only (raise an error if any metadata or shapes would change).
+            The default mode is "r+" if ``region`` is set and ``w-`` otherwise.
+        compute: If True write array data immediately; otherwise return a ``dask.delayed.Delayed`` object that
+            can be computed to write array data later. Metadata is always updated eagerly.
+        region: Optional mapping from dimension names to either a) ``"auto"``, or b) integer slices, indicating
+            the region of existing MDIO array(s) in which to write this dataset's data.
+    """
     output_path = _normalize_path(output_path)
     zarr_format = zarr.config.get("default_zarr_format")
 
-    # For GCS paths, create FSMap for fake GCS server
-    if str(output_path).startswith("gs://"):
-        import gcsfs
-        fs = gcsfs.GCSFileSystem(
-            endpoint_url="http://localhost:4443",
-            token="anon",
-        )
-        base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
-        print(f"[mdio.utils] Using fake GCS mapper via {base_url}")
-        store = fs.get_mapper(output_path.as_posix().replace("gs://", ""))
-        storage_options = None  # Must be None when passing a mapper
-    else:
-        store = output_path.as_posix()
-        storage_options = _normalize_storage_options(output_path)
-
-    print(f"[mdio.utils] Writing to store: {store}")
-    print(f"[mdio.utils] Storage options: {storage_options}")
-
-    kwargs = dict(
-        dataset=dataset,
-        store=store,
-        mode=mode,
-        compute=compute,
-        consolidated=zarr_format == ZarrFormat.V2,
-        region=region,
-        write_empty_chunks=False,
-    )
+    store, storage_options = _get_gcs_store(output_path)
+
+    kwargs = {
+        "dataset": dataset,
+        "store": store,
+        "mode": mode,
+        "compute": compute,
+        "consolidated": zarr_format == ZarrFormat.V2,
+        "region": region,
+        "write_empty_chunks": False,
+    }
+
     if storage_options is not None and not isinstance(store, dict):
         kwargs["storage_options"] = storage_options
 
     with zarr_warnings_suppress_unstable_structs_v3():
         xr_to_zarr(**kwargs)
-
-
-# def to_mdio(  # noqa: PLR0913
-#     dataset: Dataset,
-#     output_path: UPath | Path | str,
-#     mode: ZarrWriteModes | None = None,
-#     *,
-#     compute: bool = True,
-#     region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
-# ) -> None:
-#     """Write dataset contents to an MDIO output_path."""
-#     import gcsfs, zarr
-
-#     output_path = _normalize_path(output_path)
-
-#     if output_path.as_posix().startswith("gs://"):
-#         fs = gcsfs.GCSFileSystem(
-#             endpoint_url="http://localhost:4443",
-#             token="anon",
-#         )
-
-#         base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
-#         print(f"Using custom fake GCS filesystem with endpoint {base_url}")
-
-#         # Build a mapper so all I/O uses the fake GCS server
-#         mapper = fs.get_mapper(output_path.as_posix().replace("gs://", ""))
-#         store = mapper
-#         storage_options = None  # must be None when passing a mapper
-#     else:
-#         store = output_path.as_posix()
-#         storage_options = _normalize_storage_options(output_path) or {}
-
-#     print(f"Writing to store: {store}")
-#     zarr_format = zarr.config.get("default_zarr_format")
-
-#     kwargs = dict(
-#         dataset=dataset,
-#         store=store,
-#         mode=mode,
-#         compute=compute,
-#         consolidated=zarr_format == ZarrFormat.V2,
-#         region=region,
-#         write_empty_chunks=False,
-#     )
-#     if storage_options is not None:
-#         kwargs["storage_options"] = storage_options
-
-#     with zarr_warnings_suppress_unstable_structs_v3():
-#         xr_to_zarr(**kwargs)
-
-
-

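Putting the pieces together, a minimal sketch of writing through the refactored to_mdio against a local fake GCS server (assumes a fake-gcs-server is listening on localhost:4443; the bucket name, variable, and dimension names are hypothetical):

import os

os.environ["MDIO__LOCAL_GCS_SERVER"] = "true"  # read when _get_gcs_store builds MDIOSettings()

import numpy as np
import xarray as xr

from mdio.api.io import to_mdio

# Small in-memory dataset for illustration.
ds = xr.Dataset({"amplitude": (("inline", "crossline"), np.zeros((4, 4), dtype="float32"))})

# With the flag set, the gs:// path is resolved to a gcsfs mapper against the
# fake server and storage_options stays None; without it, the path falls back
# to the normal UPath storage options.
to_mdio(ds, output_path="gs://test-bucket/example.mdio", mode="w")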
src/mdio/converters/segy.py

Lines changed: 2 additions & 6 deletions
@@ -531,11 +531,9 @@ def segy_to_mdio(  # noqa PLR0913
     input_path = _normalize_path(input_path)
     output_path = _normalize_path(output_path)
 
-    print("Checking if output path exists...")
     if not overwrite and output_path.exists():
         err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
         raise FileExistsError(err)
-    print("Output path check was fine")
 
     segy_settings = SegyFileSettings(storage_options=input_path.storage_options)
     segy_file_kwargs: SegyFileArguments = {
@@ -591,16 +589,14 @@ def segy_to_mdio(  # noqa PLR0913
     # blocked_io.to_zarr() -> _workers.trace_worker()
 
     # This will create the Zarr store with the correct structure but with empty arrays
-    print("Creating Zarr store...")
     to_mdio(xr_dataset, output_path=output_path, mode="w", compute=False)
-    print("Zarr store created")
+
     # This will write the non-dimension coordinates and trace mask
     meta_ds = xr_dataset[drop_vars_delayed + ["trace_mask"]]
     to_mdio(meta_ds, output_path=output_path, mode="r+", compute=True)
-    print("Non-dimension coordinates and trace mask written")
+
     # Now we can drop them to simplify chunked write of the data variable
     xr_dataset = xr_dataset.drop_vars(drop_vars_delayed)
-    print("Dropped variables")
     # Write the headers and traces in chunks using grid_map to indicate dead traces
     default_variable_name = mdio_template.default_variable_name
     # This is an memory-expensive and time-consuming read-write operation
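The surrounding context keeps the converter's two-phase write: lay down the Zarr structure with compute=False, then fill the lightweight variables eagerly with mode="r+". A standalone sketch of that pattern (local output path, variable, and dimension names are hypothetical):

import dask.array as da
import numpy as np
import xarray as xr

from mdio.api.io import to_mdio

ds = xr.Dataset(
    {
        "amplitude": (("inline", "crossline"), da.zeros((8, 8), chunks=(4, 4), dtype="float32")),
        "trace_mask": (("inline", "crossline"), np.ones((8, 8), dtype=bool)),
    }
)

# Phase 1: create the store layout; dask-backed arrays are left unwritten.
to_mdio(ds, output_path="output/sketch.mdio", mode="w", compute=False)

# Phase 2: write only selected variables into the already-created arrays.
to_mdio(ds[["trace_mask"]], output_path="output/sketch.mdio", mode="r+", compute=True)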

src/mdio/core/config.py

Lines changed: 5 additions & 0 deletions
@@ -56,5 +56,10 @@ class MDIOSettings(BaseSettings):
         description="Whether to ignore validation checks",
         alias="MDIO_IGNORE_CHECKS",
     )
+    local_gcs_server: bool = Field(
+        default=False,
+        description="Whether to use local fake GCS server for testing (localhost:4443)",
+        alias="MDIO__LOCAL_GCS_SERVER",
+    )
 
     model_config = SettingsConfigDict(case_sensitive=True)
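A minimal sketch of how the new flag is expected to be read from the environment (assuming pydantic-settings' standard alias handling with case_sensitive=True; the value shown is illustrative):

import os

os.environ["MDIO__LOCAL_GCS_SERVER"] = "1"

from mdio.core.config import MDIOSettings

# The field defaults to False and is only enabled via the aliased environment variable.
print(MDIOSettings().local_gcs_server)  # expected: True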

src/mdio/segy/blocked_io.py

Lines changed: 4 additions & 19 deletions
@@ -15,7 +15,7 @@
 from tqdm.auto import tqdm
 from zarr import open_group as zarr_open_group
 
-from mdio.api.io import _normalize_storage_options
+from mdio.api.io import _get_gcs_store
 from mdio.builder.schemas.v1.stats import CenteredBinHistogram
 from mdio.builder.schemas.v1.stats import SummaryStatistics
 from mdio.constants import ZarrFormat
@@ -82,23 +82,9 @@ def to_zarr(  # noqa: PLR0913, PLR0915
     num_chunks = chunk_iter.num_chunks
 
     zarr_format = zarr.config.get("default_zarr_format")
-    print("Opening zarr group once in main process...")
-
+
     # Open zarr group once in main process
-    # For GCS paths with fake server, create FSMap; otherwise use path + storage_options
-    if str(output_path).startswith("gs://"):
-        import gcsfs
-        fs = gcsfs.GCSFileSystem(
-            endpoint_url="http://localhost:4443",
-            token="anon",
-        )
-        base_url = getattr(getattr(fs, "session", None), "_base_url", "http://localhost:4443")
-        print(f"[mdio.utils] Using fake GCS mapper via {base_url}")
-        store = fs.get_mapper(output_path.as_posix().replace("gs://", ""))
-        storage_options = None
-    else:
-        store = output_path.as_posix()
-        storage_options = _normalize_storage_options(output_path)
+    store, storage_options = _get_gcs_store(output_path)
 
     zarr_group = zarr_open_group(
         store,
@@ -108,7 +94,6 @@ def to_zarr(  # noqa: PLR0913, PLR0915
         zarr_version=zarr_format,
         zarr_format=zarr_format,
     )
-    print("Zarr group opened")
 
     # Get array handles from the opened group
     data_array = zarr_group[data_variable_name]
@@ -123,7 +108,7 @@ def to_zarr(  # noqa: PLR0913, PLR0915
     # 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
     num_workers = min(num_chunks, settings.import_cpus)
     context = mp.get_context("spawn")
-
+
     # Use initializer to open segy file once per worker
     executor = ProcessPoolExecutor(
         max_workers=num_workers,
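For context, a minimal sketch of the open-once pattern this refactor preserves: resolve the store with _get_gcs_store, open the Zarr group a single time in the main process, and hand out array handles rather than per-worker store connections (the path and variable name are hypothetical, and the store is assumed to already exist):

from upath import UPath
from zarr import open_group as zarr_open_group

from mdio.api.io import _get_gcs_store

output_path = UPath("output/sketch.mdio")
store, storage_options = _get_gcs_store(output_path)

# Open once in the main process; workers receive chunk indices, not stores.
zarr_group = zarr_open_group(store, mode="r+", storage_options=storage_options)
data_array = zarr_group["amplitude"]  # hypothetical variable name
print(data_array.shape)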

0 commit comments
