Skip to content

Commit 9c99816

Browse files
committed
Refactor environment managmenet to expose basic functions for lighter imports
1 parent 69c5e87 commit 9c99816

File tree

8 files changed

+169
-162
lines changed

8 files changed

+169
-162
lines changed

src/mdio/api/_environ.py

Lines changed: 107 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,116 @@
11
"""Environment variable management for MDIO operations."""
22

33
from os import getenv
4-
from typing import Final
54

65
from psutil import cpu_count
76

87
from mdio.converters.exceptions import EnvironmentFormatError
98

9+
# Environment variable keys
10+
_EXPORT_CPUS_KEY = "MDIO__EXPORT__CPU_COUNT"
11+
_IMPORT_CPUS_KEY = "MDIO__IMPORT__CPU_COUNT"
12+
_GRID_SPARSITY_RATIO_WARN_KEY = "MDIO__GRID__SPARSITY_RATIO_WARN"
13+
_GRID_SPARSITY_RATIO_LIMIT_KEY = "MDIO__GRID__SPARSITY_RATIO_LIMIT"
14+
_SAVE_SEGY_FILE_HEADER_KEY = "MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"
15+
_MDIO_SEGY_SPEC_KEY = "MDIO__SEGY__SPEC"
16+
_RAW_HEADERS_KEY = "MDIO__IMPORT__RAW_HEADERS"
17+
_IGNORE_CHECKS_KEY = "MDIO_IGNORE_CHECKS"
18+
_CLOUD_NATIVE_KEY = "MDIO__IMPORT__CLOUD_NATIVE"
1019

11-
class Environment:
12-
"""Unified API for accessing and validating MDIO environment variables."""
13-
14-
# Environment variable keys and defaults
15-
_EXPORT_CPUS_KEY: Final[str] = "MDIO__EXPORT__CPU_COUNT"
16-
_IMPORT_CPUS_KEY: Final[str] = "MDIO__IMPORT__CPU_COUNT"
17-
_GRID_SPARSITY_RATIO_WARN_KEY: Final[str] = "MDIO__GRID__SPARSITY_RATIO_WARN"
18-
_GRID_SPARSITY_RATIO_LIMIT_KEY: Final[str] = "MDIO__GRID__SPARSITY_RATIO_LIMIT"
19-
_SAVE_SEGY_FILE_HEADER_KEY: Final[str] = "MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"
20-
_MDIO_SEGY_SPEC_KEY: Final[str] = "MDIO__SEGY__SPEC"
21-
_RAW_HEADERS_KEY: Final[str] = "MDIO__IMPORT__RAW_HEADERS"
22-
_IGNORE_CHECKS_KEY: Final[str] = "MDIO_IGNORE_CHECKS"
23-
_CLOUD_NATIVE_KEY: Final[str] = "MDIO__IMPORT__CLOUD_NATIVE"
24-
25-
# Default values
26-
_EXPORT_CPUS_DEFAULT: Final[int] = cpu_count(logical=True)
27-
_IMPORT_CPUS_DEFAULT: Final[int] = cpu_count(logical=True)
28-
_GRID_SPARSITY_RATIO_WARN_DEFAULT: Final[str] = "2"
29-
_GRID_SPARSITY_RATIO_LIMIT_DEFAULT: Final[str] = "10"
30-
_SAVE_SEGY_FILE_HEADER_DEFAULT: Final[str] = "false"
31-
_MDIO_SEGY_SPEC_DEFAULT: Final[None] = None
32-
_RAW_HEADERS_DEFAULT: Final[str] = "false"
33-
_IGNORE_CHECKS_DEFAULT: Final[str] = "false"
34-
_CLOUD_NATIVE_DEFAULT: Final[str] = "false"
35-
36-
@classmethod
37-
def _get_env_value(cls, key: str, default: str | int | None) -> str | None:
38-
"""Get environment variable value with fallback to default."""
39-
if isinstance(default, int):
40-
default = str(default)
41-
return getenv(key, default)
42-
43-
@staticmethod
44-
def _parse_bool(value: str | None) -> bool:
45-
"""Parse string value to boolean."""
46-
if value is None:
47-
return False
48-
return value.lower() in ("1", "true", "yes", "on")
49-
50-
@staticmethod
51-
def _parse_int(value: str | None, key: str) -> int:
52-
"""Parse string value to integer with validation."""
53-
if value is None:
54-
raise EnvironmentFormatError(key, "int")
55-
try:
56-
return int(value)
57-
except ValueError as e:
58-
raise EnvironmentFormatError(key, "int") from e
59-
60-
@staticmethod
61-
def _parse_float(value: str | None, key: str) -> float:
62-
"""Parse string value to float with validation."""
63-
if value is None:
64-
raise EnvironmentFormatError(key, "float")
65-
try:
66-
return float(value)
67-
except ValueError as e:
68-
raise EnvironmentFormatError(key, "float") from e
69-
70-
@classmethod
71-
def export_cpus(cls) -> int:
72-
"""Number of CPUs to use for export operations."""
73-
value = cls._get_env_value(cls._EXPORT_CPUS_KEY, cls._EXPORT_CPUS_DEFAULT)
74-
return cls._parse_int(value, cls._EXPORT_CPUS_KEY)
75-
76-
@classmethod
77-
def import_cpus(cls) -> int:
78-
"""Number of CPUs to use for import operations."""
79-
value = cls._get_env_value(cls._IMPORT_CPUS_KEY, cls._IMPORT_CPUS_DEFAULT)
80-
return cls._parse_int(value, cls._IMPORT_CPUS_KEY)
81-
82-
@classmethod
83-
def grid_sparsity_ratio_warn(cls) -> float:
84-
"""Sparsity ratio threshold for warnings."""
85-
value = cls._get_env_value(cls._GRID_SPARSITY_RATIO_WARN_KEY, cls._GRID_SPARSITY_RATIO_WARN_DEFAULT)
86-
return cls._parse_float(value, cls._GRID_SPARSITY_RATIO_WARN_KEY)
87-
88-
@classmethod
89-
def grid_sparsity_ratio_limit(cls) -> float:
90-
"""Sparsity ratio threshold for errors."""
91-
value = cls._get_env_value(cls._GRID_SPARSITY_RATIO_LIMIT_KEY, cls._GRID_SPARSITY_RATIO_LIMIT_DEFAULT)
92-
return cls._parse_float(value, cls._GRID_SPARSITY_RATIO_LIMIT_KEY)
93-
94-
@classmethod
95-
def save_segy_file_header(cls) -> bool:
96-
"""Whether to save SEG-Y file headers."""
97-
value = cls._get_env_value(cls._SAVE_SEGY_FILE_HEADER_KEY, cls._SAVE_SEGY_FILE_HEADER_DEFAULT)
98-
return cls._parse_bool(value)
99-
100-
@classmethod
101-
def mdio_segy_spec(cls) -> str | None:
102-
"""Path to MDIO SEG-Y specification file."""
103-
return cls._get_env_value(cls._MDIO_SEGY_SPEC_KEY, cls._MDIO_SEGY_SPEC_DEFAULT)
104-
105-
@classmethod
106-
def raw_headers(cls) -> bool:
107-
"""Whether to preserve raw headers."""
108-
value = cls._get_env_value(cls._RAW_HEADERS_KEY, cls._RAW_HEADERS_DEFAULT)
109-
return cls._parse_bool(value)
110-
111-
@classmethod
112-
def ignore_checks(cls) -> bool:
113-
"""Whether to ignore validation checks."""
114-
value = cls._get_env_value(cls._IGNORE_CHECKS_KEY, cls._IGNORE_CHECKS_DEFAULT)
115-
return cls._parse_bool(value)
116-
117-
@classmethod
118-
def cloud_native(cls) -> bool:
119-
"""Whether to use cloud-native mode for SEG-Y processing."""
120-
value = cls._get_env_value(cls._CLOUD_NATIVE_KEY, cls._CLOUD_NATIVE_DEFAULT)
121-
return cls._parse_bool(value)
20+
# Default values
21+
_EXPORT_CPUS_DEFAULT = cpu_count(logical=True)
22+
_IMPORT_CPUS_DEFAULT = cpu_count(logical=True)
23+
_GRID_SPARSITY_RATIO_WARN_DEFAULT = "2"
24+
_GRID_SPARSITY_RATIO_LIMIT_DEFAULT = "10"
25+
_SAVE_SEGY_FILE_HEADER_DEFAULT = "false"
26+
_MDIO_SEGY_SPEC_DEFAULT = None
27+
_RAW_HEADERS_DEFAULT = "false"
28+
_IGNORE_CHECKS_DEFAULT = "false"
29+
_CLOUD_NATIVE_DEFAULT = "false"
30+
31+
32+
def _get_env_value(key: str, default: str | int | None) -> str | None:
33+
"""Get environment variable value with fallback to default."""
34+
if isinstance(default, int):
35+
default = str(default)
36+
return getenv(key, default)
37+
38+
39+
def _parse_bool(value: str | None) -> bool:
40+
"""Parse string value to boolean."""
41+
if value is None:
42+
return False
43+
return value.lower() in ("1", "true", "yes", "on")
44+
45+
46+
def _parse_int(value: str | None, key: str) -> int:
47+
"""Parse string value to integer with validation."""
48+
if value is None:
49+
raise EnvironmentFormatError(key, "int")
50+
try:
51+
return int(value)
52+
except ValueError as e:
53+
raise EnvironmentFormatError(key, "int") from e
54+
55+
56+
def _parse_float(value: str | None, key: str) -> float:
57+
"""Parse string value to float with validation."""
58+
if value is None:
59+
raise EnvironmentFormatError(key, "float")
60+
try:
61+
return float(value)
62+
except ValueError as e:
63+
raise EnvironmentFormatError(key, "float") from e
64+
65+
66+
def export_cpus() -> int:
67+
"""Number of CPUs to use for export operations."""
68+
value = _get_env_value(_EXPORT_CPUS_KEY, _EXPORT_CPUS_DEFAULT)
69+
return _parse_int(value, _EXPORT_CPUS_KEY)
70+
71+
72+
def import_cpus() -> int:
73+
"""Number of CPUs to use for import operations."""
74+
value = _get_env_value(_IMPORT_CPUS_KEY, _IMPORT_CPUS_DEFAULT)
75+
return _parse_int(value, _IMPORT_CPUS_KEY)
76+
77+
78+
def grid_sparsity_ratio_warn() -> float:
79+
"""Sparsity ratio threshold for warnings."""
80+
value = _get_env_value(_GRID_SPARSITY_RATIO_WARN_KEY, _GRID_SPARSITY_RATIO_WARN_DEFAULT)
81+
return _parse_float(value, _GRID_SPARSITY_RATIO_WARN_KEY)
82+
83+
84+
def grid_sparsity_ratio_limit() -> float:
85+
"""Sparsity ratio threshold for errors."""
86+
value = _get_env_value(_GRID_SPARSITY_RATIO_LIMIT_KEY, _GRID_SPARSITY_RATIO_LIMIT_DEFAULT)
87+
return _parse_float(value, _GRID_SPARSITY_RATIO_LIMIT_KEY)
88+
89+
90+
def save_segy_file_header() -> bool:
91+
"""Whether to save SEG-Y file headers."""
92+
value = _get_env_value(_SAVE_SEGY_FILE_HEADER_KEY, _SAVE_SEGY_FILE_HEADER_DEFAULT)
93+
return _parse_bool(value)
94+
95+
96+
def mdio_segy_spec() -> str | None:
97+
"""Path to MDIO SEG-Y specification file."""
98+
return _get_env_value(_MDIO_SEGY_SPEC_KEY, _MDIO_SEGY_SPEC_DEFAULT)
99+
100+
101+
def raw_headers() -> bool:
102+
"""Whether to preserve raw headers."""
103+
value = _get_env_value(_RAW_HEADERS_KEY, _RAW_HEADERS_DEFAULT)
104+
return _parse_bool(value)
105+
106+
107+
def ignore_checks() -> bool:
108+
"""Whether to ignore validation checks."""
109+
value = _get_env_value(_IGNORE_CHECKS_KEY, _IGNORE_CHECKS_DEFAULT)
110+
return _parse_bool(value)
111+
112+
113+
def cloud_native() -> bool:
114+
"""Whether to use cloud-native mode for SEG-Y processing."""
115+
value = _get_env_value(_CLOUD_NATIVE_KEY, _CLOUD_NATIVE_DEFAULT)
116+
return _parse_bool(value)

src/mdio/converters/mdio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from tqdm.dask import TqdmCallback
1010

11-
from mdio.api._environ import Environment
11+
from mdio.api._environ import export_cpus
1212
from mdio.api.io import _normalize_path
1313
from mdio.api.io import open_mdio
1414
from mdio.segy.blocked_io import to_segy
@@ -143,7 +143,7 @@ def mdio_to_segy( # noqa: PLR0912, PLR0913, PLR0915
143143
if client is not None:
144144
block_records = block_records.compute()
145145
else:
146-
block_records = block_records.compute(num_workers=Environment.export_cpus())
146+
block_records = block_records.compute(num_workers=export_cpus())
147147

148148
ordered_files = [rec.path for rec in block_records.ravel() if rec != 0]
149149
ordered_files = [output_path] + ordered_files

src/mdio/converters/segy.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from segy.standards.codes import MeasurementSystem as SegyMeasurementSystem
1414
from segy.standards.fields import binary as binary_header_fields
1515

16-
from mdio.api._environ import Environment
16+
from mdio.api._environ import grid_sparsity_ratio_limit
17+
from mdio.api._environ import grid_sparsity_ratio_warn
18+
from mdio.api._environ import ignore_checks as ignore_checks_env
19+
from mdio.api._environ import raw_headers
20+
from mdio.api._environ import save_segy_file_header
1721
from mdio.api.io import _normalize_path
1822
from mdio.api.io import to_mdio
1923
from mdio.builder.schemas.chunk_grid import RegularChunkGrid
@@ -99,9 +103,9 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None:
99103
sparsity_ratio = float("inf") if num_traces == 0 else grid_traces / num_traces
100104

101105
# Fetch and validate environment variables
102-
warning_ratio = Environment.grid_sparsity_ratio_warn()
103-
error_ratio = Environment.grid_sparsity_ratio_limit()
104-
ignore_checks = Environment.ignore_checks()
106+
warning_ratio = grid_sparsity_ratio_warn()
107+
error_ratio = grid_sparsity_ratio_limit()
108+
ignore_checks = ignore_checks_env()
105109

106110
# Check sparsity
107111
should_warn = sparsity_ratio > warning_ratio
@@ -359,7 +363,7 @@ def _populate_coordinates(
359363

360364

361365
def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo) -> xr_Dataset:
362-
if not Environment.save_segy_file_header():
366+
if not save_segy_file_header():
363367
return xr_dataset
364368

365369
expected_rows = 40
@@ -383,7 +387,7 @@ def _add_segy_file_headers(xr_dataset: xr_Dataset, segy_file_info: SegyFileInfo)
383387
"binaryHeader": segy_file_info.binary_header_dict,
384388
}
385389
)
386-
if Environment.raw_headers():
390+
if raw_headers():
387391
raw_binary_base64 = base64.b64encode(segy_file_info.raw_binary_headers).decode("ascii")
388392
xr_dataset["segy_file_header"].attrs.update({"rawBinaryHeader": raw_binary_base64})
389393

@@ -550,7 +554,7 @@ def segy_to_mdio( # noqa PLR0913
550554
_, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
551555
header_dtype = to_structured_type(segy_spec.trace.header.dtype)
552556

553-
if Environment.raw_headers():
557+
if raw_headers():
554558
if zarr.config.get("default_zarr_format") == ZarrFormat.V2:
555559
logger.warning("Raw headers are only supported for Zarr v3. Skipping raw headers.")
556560
else:

src/mdio/segy/_workers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from segy.arrays import HeaderArray
1010

11-
from mdio.api._environ import Environment
11+
from mdio.api._environ import cloud_native
1212
from mdio.api.io import _normalize_storage_options
1313
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
1414
from mdio.segy.file import SegyFileArguments
@@ -50,7 +50,7 @@ def header_scan_worker(
5050

5151
slice_ = slice(*trace_range)
5252

53-
trace_header = segy_file.trace[slice_].header if Environment.cloud_native() else segy_file.header[slice_]
53+
trace_header = segy_file.trace[slice_].header if cloud_native() else segy_file.header[slice_]
5454

5555
if subset is not None:
5656
# struct field selection needs a list, not a tuple; a subset is a tuple from the template.

src/mdio/segy/blocked_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from tqdm.auto import tqdm
1616
from zarr import open_group as zarr_open_group
1717

18-
from mdio.api._environ import Environment
18+
from mdio.api._environ import import_cpus
1919
from mdio.api.io import _normalize_storage_options
2020
from mdio.builder.schemas.v1.stats import CenteredBinHistogram
2121
from mdio.builder.schemas.v1.stats import SummaryStatistics
@@ -80,7 +80,7 @@ def to_zarr( # noqa: PLR0913, PLR0915
8080

8181
# For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
8282
# 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
83-
num_workers = min(num_chunks, Environment.import_cpus())
83+
num_workers = min(num_chunks, import_cpus())
8484
context = mp.get_context("spawn")
8585
executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
8686

src/mdio/segy/compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from segy.schema import TraceSpec
2323
from segy.standards.fields import binary
2424

25-
from mdio.api._environ import Environment
25+
from mdio.api._environ import mdio_segy_spec as get_mdio_segy_spec
2626
from mdio.exceptions import InvalidMDIOError
2727

2828
MDIO_VERSION = metadata.version("multidimio")
@@ -74,7 +74,7 @@ def get_trace_fields(version_str: str) -> list[HeaderField]:
7474

7575
def mdio_segy_spec(version_str: str | None = None) -> SegySpec:
7676
"""Get a SEG-Y encoding spec for MDIO based on version."""
77-
spec_override = Environment.mdio_segy_spec()
77+
spec_override = get_mdio_segy_spec()
7878

7979
if spec_override is not None:
8080
return SegySpec.model_validate_json(spec_override)

src/mdio/segy/parsers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
from tqdm.auto import tqdm
1313

14-
from mdio.api._environ import Environment
14+
from mdio.api._environ import import_cpus
1515
from mdio.segy._workers import header_scan_worker
1616

1717
if TYPE_CHECKING:
@@ -50,7 +50,7 @@ def parse_headers(
5050

5151
trace_ranges.append((start, stop))
5252

53-
num_workers = min(n_blocks, Environment.import_cpus())
53+
num_workers = min(n_blocks, import_cpus())
5454

5555
tqdm_kw = {"unit": "block", "dynamic_ncols": True}
5656
# For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default

0 commit comments

Comments
 (0)