Skip to content

Commit 879ddca

Browse files
BrianMichell authored and tasansal committed
Add isolated environment variable getters
1 parent 8c40d61 commit 879ddca

File tree

8 files changed

+251
-49
lines changed

8 files changed

+251
-49
lines changed

src/mdio/api/_environ.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""Environment variable management for MDIO operations."""
2+
3+
from os import getenv
4+
from typing import Final
5+
6+
from psutil import cpu_count
7+
8+
from mdio.converters.exceptions import EnvironmentFormatError
9+
10+
11+
class Environment:
    """Unified API for accessing and validating MDIO environment variables.

    Every public accessor is a classmethod that reads the process environment at call
    time (nothing is cached), falls back to the matching ``_*_DEFAULT`` when the variable
    is unset, and converts the raw string to the type it advertises.
    """

    # Environment variable keys
    _EXPORT_CPUS_KEY: Final[str] = "MDIO__EXPORT__CPU_COUNT"
    _IMPORT_CPUS_KEY: Final[str] = "MDIO__IMPORT__CPU_COUNT"
    _GRID_SPARSITY_RATIO_WARN_KEY: Final[str] = "MDIO__GRID__SPARSITY_RATIO_WARN"
    _GRID_SPARSITY_RATIO_LIMIT_KEY: Final[str] = "MDIO__GRID__SPARSITY_RATIO_LIMIT"
    _SAVE_SEGY_FILE_HEADER_KEY: Final[str] = "MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"
    _MDIO_SEGY_SPEC_KEY: Final[str] = "MDIO__SEGY__SPEC"
    _RAW_HEADERS_KEY: Final[str] = "MDIO__IMPORT__RAW_HEADERS"
    # NOTE: single underscores, unlike the other keys; the name is kept as-is.
    _IGNORE_CHECKS_KEY: Final[str] = "MDIO_IGNORE_CHECKS"
    _CLOUD_NATIVE_KEY: Final[str] = "MDIO__IMPORT__CLOUD_NATIVE"

    # Default values
    # psutil documents that ``cpu_count`` returns ``None`` when the count cannot be
    # determined; fall back to a single worker so the defaults really are integers and
    # the CPU getters do not fail when the variable is simply unset.
    _EXPORT_CPUS_DEFAULT: Final[int] = cpu_count(logical=True) or 1
    _IMPORT_CPUS_DEFAULT: Final[int] = cpu_count(logical=True) or 1
    _GRID_SPARSITY_RATIO_WARN_DEFAULT: Final[str] = "2"
    _GRID_SPARSITY_RATIO_LIMIT_DEFAULT: Final[str] = "10"
    _SAVE_SEGY_FILE_HEADER_DEFAULT: Final[str] = "false"
    _MDIO_SEGY_SPEC_DEFAULT: Final[None] = None
    _RAW_HEADERS_DEFAULT: Final[str] = "false"
    _IGNORE_CHECKS_DEFAULT: Final[str] = "false"
    _CLOUD_NATIVE_DEFAULT: Final[str] = "false"

    @classmethod
    def _get_env_value(cls, key: str, default: str | int | None) -> str | None:
        """Get environment variable value with fallback to default.

        Args:
            key: Name of the environment variable to read.
            default: Value to use when the variable is unset. Integers are converted to
                their string form so the result is always ``str`` or ``None``.

        Returns:
            The raw string from the environment, the stringified default, or ``None``
            when the variable is unset and no default exists.
        """
        if isinstance(default, int):
            default = str(default)
        return getenv(key, default)

    @staticmethod
    def _parse_bool(value: str | None) -> bool:
        """Parse string value to boolean.

        Args:
            value: Raw string to interpret, or ``None``.

        Returns:
            ``True`` when the value, ignoring case and surrounding whitespace, is one of
            ``"1"``, ``"true"``, ``"yes"`` or ``"on"``; ``False`` otherwise (including
            ``None``).
        """
        if value is None:
            return False
        # Strip so values such as "true " (e.g. from a hand-edited env file) still count.
        return value.strip().lower() in ("1", "true", "yes", "on")

    @staticmethod
    def _parse_int(value: str | None, key: str) -> int:
        """Parse string value to integer with validation.

        Args:
            value: Raw string to convert, or ``None``.
            key: Name of the environment variable, used for error reporting.

        Returns:
            The parsed integer.

        Raises:
            EnvironmentFormatError: If ``value`` is ``None`` or is not a valid integer.
        """
        if value is None:
            raise EnvironmentFormatError(key, "int")
        try:
            return int(value)
        except ValueError as e:
            raise EnvironmentFormatError(key, "int") from e

    @staticmethod
    def _parse_float(value: str | None, key: str) -> float:
        """Parse string value to float with validation.

        Args:
            value: Raw string to convert, or ``None``.
            key: Name of the environment variable, used for error reporting.

        Returns:
            The parsed float.

        Raises:
            EnvironmentFormatError: If ``value`` is ``None`` or is not a valid float.
        """
        if value is None:
            raise EnvironmentFormatError(key, "float")
        try:
            return float(value)
        except ValueError as e:
            raise EnvironmentFormatError(key, "float") from e

    @classmethod
    def export_cpus(cls) -> int:
        """Number of CPUs to use for export operations.

        Reads ``MDIO__EXPORT__CPU_COUNT``; defaults to the logical CPU count.

        Raises:
            EnvironmentFormatError: If the variable is set to a non-integer value.
        """
        value = cls._get_env_value(cls._EXPORT_CPUS_KEY, cls._EXPORT_CPUS_DEFAULT)
        return cls._parse_int(value, cls._EXPORT_CPUS_KEY)

    @classmethod
    def import_cpus(cls) -> int:
        """Number of CPUs to use for import operations.

        Reads ``MDIO__IMPORT__CPU_COUNT``; defaults to the logical CPU count.

        Raises:
            EnvironmentFormatError: If the variable is set to a non-integer value.
        """
        value = cls._get_env_value(cls._IMPORT_CPUS_KEY, cls._IMPORT_CPUS_DEFAULT)
        return cls._parse_int(value, cls._IMPORT_CPUS_KEY)

    @classmethod
    def grid_sparsity_ratio_warn(cls) -> float:
        """Sparsity ratio threshold for warnings.

        Reads ``MDIO__GRID__SPARSITY_RATIO_WARN``; defaults to ``2``.

        Raises:
            EnvironmentFormatError: If the variable cannot be converted to a float.
        """
        value = cls._get_env_value(cls._GRID_SPARSITY_RATIO_WARN_KEY, cls._GRID_SPARSITY_RATIO_WARN_DEFAULT)
        return cls._parse_float(value, cls._GRID_SPARSITY_RATIO_WARN_KEY)

    @classmethod
    def grid_sparsity_ratio_limit(cls) -> float:
        """Sparsity ratio threshold for errors.

        Reads ``MDIO__GRID__SPARSITY_RATIO_LIMIT``; defaults to ``10``.

        Raises:
            EnvironmentFormatError: If the variable cannot be converted to a float.
        """
        value = cls._get_env_value(cls._GRID_SPARSITY_RATIO_LIMIT_KEY, cls._GRID_SPARSITY_RATIO_LIMIT_DEFAULT)
        return cls._parse_float(value, cls._GRID_SPARSITY_RATIO_LIMIT_KEY)

    @classmethod
    def save_segy_file_header(cls) -> bool:
        """Whether to save SEG-Y file headers.

        Reads ``MDIO__IMPORT__SAVE_SEGY_FILE_HEADER``; defaults to false.
        """
        value = cls._get_env_value(cls._SAVE_SEGY_FILE_HEADER_KEY, cls._SAVE_SEGY_FILE_HEADER_DEFAULT)
        return cls._parse_bool(value)

    @classmethod
    def mdio_segy_spec(cls) -> str | None:
        """SEG-Y spec override for MDIO, as a JSON string.

        Reads ``MDIO__SEGY__SPEC`` and returns it unparsed. The consumer validates the
        value as a JSON-serialized ``SegySpec``, so it is the spec document itself rather
        than a path to a file.

        Returns:
            The raw value of the variable, or ``None`` when it is unset.
        """
        return cls._get_env_value(cls._MDIO_SEGY_SPEC_KEY, cls._MDIO_SEGY_SPEC_DEFAULT)

    @classmethod
    def raw_headers(cls) -> bool:
        """Whether to preserve raw headers.

        Reads ``MDIO__IMPORT__RAW_HEADERS``; defaults to false.
        """
        value = cls._get_env_value(cls._RAW_HEADERS_KEY, cls._RAW_HEADERS_DEFAULT)
        return cls._parse_bool(value)

    @classmethod
    def ignore_checks(cls) -> bool:
        """Whether to ignore validation checks.

        Reads ``MDIO_IGNORE_CHECKS``; defaults to false.
        """
        value = cls._get_env_value(cls._IGNORE_CHECKS_KEY, cls._IGNORE_CHECKS_DEFAULT)
        return cls._parse_bool(value)

    @classmethod
    def cloud_native(cls) -> bool:
        """Whether to use cloud-native mode for SEG-Y processing.

        Reads ``MDIO__IMPORT__CLOUD_NATIVE``; defaults to false.
        """
        value = cls._get_env_value(cls._CLOUD_NATIVE_KEY, cls._CLOUD_NATIVE_DEFAULT)
        return cls._parse_bool(value)

src/mdio/converters/mdio.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
from __future__ import annotations
44

5-
import os
65
from tempfile import TemporaryDirectory
76
from typing import TYPE_CHECKING
87

98
import numpy as np
10-
from psutil import cpu_count
119
from tqdm.dask import TqdmCallback
1210

11+
from mdio.api._environ import Environment
1312
from mdio.api.io import _normalize_path
1413
from mdio.api.io import open_mdio
1514
from mdio.segy.blocked_io import to_segy
@@ -29,10 +28,6 @@
2928
from upath import UPath
3029

3130

32-
default_cpus = cpu_count(logical=True)
33-
NUM_CPUS = int(os.getenv("MDIO__EXPORT__CPU_COUNT", default_cpus))
34-
35-
3631
def mdio_to_segy( # noqa: PLR0912, PLR0913, PLR0915
3732
segy_spec: SegySpec,
3833
input_path: UPath | Path | str,
@@ -148,7 +143,7 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913, PLR0915
148143
if client is not None:
149144
block_records = block_records.compute()
150145
else:
151-
block_records = block_records.compute(num_workers=NUM_CPUS)
146+
block_records = block_records.compute(num_workers=Environment.export_cpus())
152147

153148
ordered_files = [rec.path for rec in block_records.ravel() if rec != 0]
154149
ordered_files = [output_path] + ordered_files

src/mdio/converters/segy.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import base64
66
import logging
7-
import os
87
from typing import TYPE_CHECKING
98

109
import numpy as np
@@ -14,6 +13,7 @@
1413
from segy.standards.codes import MeasurementSystem as SegyMeasurementSystem
1514
from segy.standards.fields import binary as binary_header_fields
1615

16+
from mdio.api._environ import Environment
1717
from mdio.api.io import _normalize_path
1818
from mdio.api.io import to_mdio
1919
from mdio.builder.schemas.chunk_grid import RegularChunkGrid
@@ -28,7 +28,6 @@
2828
from mdio.builder.schemas.v1.variable import VariableMetadata
2929
from mdio.builder.xarray_builder import to_xarray_dataset
3030
from mdio.constants import ZarrFormat
31-
from mdio.converters.exceptions import EnvironmentFormatError
3231
from mdio.converters.exceptions import GridTraceCountError
3332
from mdio.converters.exceptions import GridTraceSparsityError
3433
from mdio.converters.type_converter import to_structured_type
@@ -92,8 +91,6 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None:
9291
Raises:
9392
GridTraceSparsityError: If the sparsity ratio exceeds `MDIO__GRID__SPARSITY_RATIO_LIMIT`
9493
and `MDIO_IGNORE_CHECKS` is not set to a truthy value (e.g., "1", "true").
95-
EnvironmentFormatError: If `MDIO__GRID__SPARSITY_RATIO_WARN` or
96-
`MDIO__GRID__SPARSITY_RATIO_LIMIT` cannot be converted to a float.
9794
"""
9895
# Calculate total possible traces in the grid (excluding sample dimension)
9996
grid_traces = np.prod(grid.shape[:-1], dtype=np.uint64)
@@ -102,20 +99,9 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None:
10299
sparsity_ratio = float("inf") if num_traces == 0 else grid_traces / num_traces
103100

104101
# Fetch and validate environment variables
105-
warning_ratio_env = os.getenv("MDIO__GRID__SPARSITY_RATIO_WARN", "2")
106-
error_ratio_env = os.getenv("MDIO__GRID__SPARSITY_RATIO_LIMIT", "10")
107-
ignore_checks_env = os.getenv("MDIO_IGNORE_CHECKS", "false").lower()
108-
ignore_checks = ignore_checks_env in ("1", "true", "yes", "on")
109-
110-
try:
111-
warning_ratio = float(warning_ratio_env)
112-
except ValueError as e:
113-
raise EnvironmentFormatError("MDIO__GRID__SPARSITY_RATIO_WARN", "float") from e # noqa: EM101
114-
115-
try:
116-
error_ratio = float(error_ratio_env)
117-
except ValueError as e:
118-
raise EnvironmentFormatError("MDIO__GRID__SPARSITY_RATIO_LIMIT", "float") from e # noqa: EM101
102+
warning_ratio = Environment.grid_sparsity_ratio_warn()
103+
error_ratio = Environment.grid_sparsity_ratio_limit()
104+
ignore_checks = Environment.ignore_checks()
119105

120106
# Check sparsity
121107
should_warn = sparsity_ratio > warning_ratio
@@ -373,7 +359,7 @@ def _populate_coordinates(
373359

374360

375361
def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo) -> xr_Dataset:
376-
save_file_header = os.getenv("MDIO__IMPORT__SAVE_SEGY_FILE_HEADER", "") in ("1", "true", "yes", "on")
362+
save_file_header = Environment.save_segy_file_header()
377363
if not save_file_header:
378364
return xr_dataset
379365

@@ -398,7 +384,7 @@ def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo)
398384
"binaryHeader": segy_file_info.binary_header_dict,
399385
}
400386
)
401-
if os.getenv("MDIO__IMPORT__RAW_HEADERS") in ("1", "true", "yes", "on"):
387+
if Environment.raw_headers():
402388
raw_binary_base64 = base64.b64encode(segy_file_info.raw_binary_headers).decode("ascii")
403389
xr_dataset["segy_file_header"].attrs.update({"rawBinaryHeader": raw_binary_base64})
404390

@@ -565,7 +551,7 @@ def segy_to_mdio( # noqa PLR0913
565551
_, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
566552
header_dtype = to_structured_type(segy_spec.trace.header.dtype)
567553

568-
if os.getenv("MDIO__IMPORT__RAW_HEADERS") in ("1", "true", "yes", "on"):
554+
if Environment.raw_headers():
569555
if zarr.config.get("default_zarr_format") == ZarrFormat.V2:
570556
logger.warning("Raw headers are only supported for Zarr v3. Skipping raw headers.")
571557
else:

src/mdio/segy/_workers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from __future__ import annotations
44

55
import logging
6-
import os
76
from typing import TYPE_CHECKING
87

98
import numpy as np
109
from segy.arrays import HeaderArray
1110

11+
from mdio.api._environ import Environment
1212
from mdio.api.io import _normalize_storage_options
1313
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
1414
from mdio.segy.file import SegyFileArguments
@@ -50,12 +50,7 @@ def header_scan_worker(
5050

5151
slice_ = slice(*trace_range)
5252

53-
cloud_native_mode = os.getenv("MDIO__IMPORT__CLOUD_NATIVE", default="False")
54-
55-
if cloud_native_mode.lower() in {"true", "1"}:
56-
trace_header = segy_file.trace[slice_].header
57-
else:
58-
trace_header = segy_file.header[slice_]
53+
trace_header = segy_file.trace[slice_].header if Environment.cloud_native() else segy_file.header[slice_]
5954

6055
if subset is not None:
6156
# struct field selection needs a list, not a tuple; a subset is a tuple from the template.

src/mdio/segy/blocked_io.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import multiprocessing as mp
6-
import os
76
from concurrent.futures import ProcessPoolExecutor
87
from concurrent.futures import as_completed
98
from pathlib import Path
@@ -13,10 +12,10 @@
1312
import zarr
1413
from dask.array import Array
1514
from dask.array import map_blocks
16-
from psutil import cpu_count
1715
from tqdm.auto import tqdm
1816
from zarr import open_group as zarr_open_group
1917

18+
from mdio.api._environ import Environment
2019
from mdio.api.io import _normalize_storage_options
2120
from mdio.builder.schemas.v1.stats import CenteredBinHistogram
2221
from mdio.builder.schemas.v1.stats import SummaryStatistics
@@ -37,8 +36,6 @@
3736

3837
from mdio.segy.file import SegyFileArguments
3938

40-
default_cpus = cpu_count(logical=True)
41-
4239

4340
def _create_stats() -> SummaryStatistics:
4441
histogram = CenteredBinHistogram(bin_centers=[], counts=[])
@@ -83,8 +80,7 @@ def to_zarr( # noqa: PLR0913, PLR0915
8380

8481
# For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
8582
# 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
86-
num_cpus = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
87-
num_workers = min(num_chunks, num_cpus)
83+
num_workers = min(num_chunks, Environment.import_cpus())
8884
context = mp.get_context("spawn")
8985
executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
9086

src/mdio/segy/compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import annotations
99

1010
import logging
11-
import os
1211
from importlib import metadata
1312

1413
from packaging import version
@@ -23,6 +22,7 @@
2322
from segy.schema import TraceSpec
2423
from segy.standards.fields import binary
2524

25+
from mdio.api._environ import Environment
2626
from mdio.exceptions import InvalidMDIOError
2727

2828
MDIO_VERSION = metadata.version("multidimio")
@@ -74,7 +74,7 @@ def get_trace_fields(version_str: str) -> list[HeaderField]:
7474

7575
def mdio_segy_spec(version_str: str | None = None) -> SegySpec:
7676
"""Get a SEG-Y encoding spec for MDIO based on version."""
77-
spec_override = os.getenv("MDIO__SEGY__SPEC")
77+
spec_override = Environment.mdio_segy_spec()
7878

7979
if spec_override is not None:
8080
return SegySpec.model_validate_json(spec_override)

src/mdio/segy/parsers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,22 @@
33
from __future__ import annotations
44

55
import multiprocessing as mp
6-
import os
76
from concurrent.futures import ProcessPoolExecutor
87
from itertools import repeat
98
from math import ceil
109
from typing import TYPE_CHECKING
1110

1211
import numpy as np
13-
from psutil import cpu_count
1412
from tqdm.auto import tqdm
1513

14+
from mdio.api._environ import Environment
1615
from mdio.segy._workers import header_scan_worker
1716

1817
if TYPE_CHECKING:
1918
from segy.arrays import HeaderArray
2019

2120
from mdio.segy.file import SegyFileArguments
2221

23-
default_cpus = cpu_count(logical=True)
24-
2522

2623
def parse_headers(
2724
segy_file_kwargs: SegyFileArguments,
@@ -53,8 +50,7 @@ def parse_headers(
5350

5451
trace_ranges.append((start, stop))
5552

56-
num_cpus = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus))
57-
num_workers = min(n_blocks, num_cpus)
53+
num_workers = min(n_blocks, Environment.import_cpus())
5854

5955
tqdm_kw = {"unit": "block", "dynamic_ncols": True}
6056
# For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default

0 commit comments

Comments
 (0)