Skip to content

Commit 860339a

Browse files
committed
Update to use pydantic-settings
1 parent e6e5c85 commit 860339a

File tree

3 files changed

+115
-72
lines changed

3 files changed

+115
-72
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"pint>=0.25.0",
2626
"psutil>=7.1.0",
2727
"pydantic>=2.12.0",
28+
"pydantic-settings (>=2.6.1,<3.0.0)",
2829
"rich>=14.1.0",
2930
"segy>=0.5.3",
3031
"tqdm>=4.67.1",

src/mdio/api/_environ.py

Lines changed: 112 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,156 @@
11
"""Environment variable management for MDIO operations."""
22

3-
from os import getenv
4-
53
from psutil import cpu_count
4+
from pydantic import ConfigDict
5+
from pydantic import Field
6+
from pydantic import ValidationError
7+
from pydantic import field_validator
8+
from pydantic_settings import BaseSettings
69

710
from mdio.converters.exceptions import EnvironmentFormatError
811

9-
# Environment variable keys
10-
_EXPORT_CPUS_KEY = "MDIO__EXPORT__CPU_COUNT"
11-
_IMPORT_CPUS_KEY = "MDIO__IMPORT__CPU_COUNT"
12-
_GRID_SPARSITY_RATIO_WARN_KEY = "MDIO__GRID__SPARSITY_RATIO_WARN"
13-
_GRID_SPARSITY_RATIO_LIMIT_KEY = "MDIO__GRID__SPARSITY_RATIO_LIMIT"
14-
_SAVE_SEGY_FILE_HEADER_KEY = "MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"
15-
_MDIO_SEGY_SPEC_KEY = "MDIO__SEGY__SPEC"
16-
_RAW_HEADERS_KEY = "MDIO__IMPORT__RAW_HEADERS"
17-
_IGNORE_CHECKS_KEY = "MDIO_IGNORE_CHECKS"
18-
_CLOUD_NATIVE_KEY = "MDIO__IMPORT__CLOUD_NATIVE"
19-
20-
# Default values
21-
_EXPORT_CPUS_DEFAULT = cpu_count(logical=True)
22-
_IMPORT_CPUS_DEFAULT = cpu_count(logical=True)
23-
_GRID_SPARSITY_RATIO_WARN_DEFAULT = "2"
24-
_GRID_SPARSITY_RATIO_LIMIT_DEFAULT = "10"
25-
_SAVE_SEGY_FILE_HEADER_DEFAULT = "false"
26-
_MDIO_SEGY_SPEC_DEFAULT = None
27-
_RAW_HEADERS_DEFAULT = "false"
28-
_IGNORE_CHECKS_DEFAULT = "false"
29-
_CLOUD_NATIVE_DEFAULT = "false"
30-
31-
32-
def _get_env_value(key: str, default: str | int | None) -> str | None:
33-
"""Get environment variable value with fallback to default."""
34-
if isinstance(default, int):
35-
default = str(default)
36-
return getenv(key, default)
37-
38-
39-
def _parse_bool(value: str | None) -> bool:
40-
"""Parse string value to boolean."""
41-
if value is None:
42-
return False
43-
return value.lower() in ("1", "true", "yes", "on")
44-
45-
46-
def _parse_int(value: str | None, key: str) -> int:
47-
"""Parse string value to integer with validation."""
48-
if value is None:
49-
raise EnvironmentFormatError(key, "int")
50-
try:
51-
return int(value)
52-
except ValueError as e:
53-
raise EnvironmentFormatError(key, "int") from e
54-
5512

56-
def _parse_float(value: str | None, key: str) -> float:
57-
"""Parse string value to float with validation."""
58-
if value is None:
59-
raise EnvironmentFormatError(key, "float")
13+
class MDIOSettings(BaseSettings):
14+
"""MDIO environment configuration settings."""
15+
16+
# CPU configuration
17+
export_cpus: int = Field(
18+
default_factory=lambda: cpu_count(logical=True),
19+
description="Number of CPUs to use for export operations",
20+
alias="MDIO__EXPORT__CPU_COUNT",
21+
)
22+
import_cpus: int = Field(
23+
default_factory=lambda: cpu_count(logical=True),
24+
description="Number of CPUs to use for import operations",
25+
alias="MDIO__IMPORT__CPU_COUNT",
26+
)
27+
28+
# Grid sparsity configuration
29+
grid_sparsity_ratio_warn: float = Field(
30+
default=2.0,
31+
description="Sparsity ratio threshold for warnings",
32+
alias="MDIO__GRID__SPARSITY_RATIO_WARN",
33+
)
34+
grid_sparsity_ratio_limit: float = Field(
35+
default=10.0,
36+
description="Sparsity ratio threshold for errors",
37+
alias="MDIO__GRID__SPARSITY_RATIO_LIMIT",
38+
)
39+
40+
# Import configuration
41+
save_segy_file_header: bool = Field(
42+
default=False,
43+
description="Whether to save SEG-Y file headers",
44+
alias="MDIO__IMPORT__SAVE_SEGY_FILE_HEADER",
45+
)
46+
raw_headers: bool = Field(
47+
default=False,
48+
description="Whether to preserve raw headers",
49+
alias="MDIO__IMPORT__RAW_HEADERS",
50+
)
51+
cloud_native: bool = Field(
52+
default=False,
53+
description="Whether to use cloud-native mode for SEG-Y processing",
54+
alias="MDIO__IMPORT__CLOUD_NATIVE",
55+
)
56+
57+
# General configuration
58+
mdio_segy_spec: str | None = Field(
59+
default=None,
60+
description="Path to MDIO SEG-Y specification file",
61+
alias="MDIO__SEGY__SPEC",
62+
)
63+
ignore_checks: bool = Field(
64+
default=False,
65+
description="Whether to ignore validation checks",
66+
alias="MDIO_IGNORE_CHECKS",
67+
)
68+
69+
model_config = ConfigDict(
70+
env_prefix="",
71+
case_sensitive=True,
72+
)
73+
74+
@field_validator("save_segy_file_header", "raw_headers", "ignore_checks", "cloud_native", mode="before")
75+
@classmethod
76+
def parse_bool_fields(cls, v: object) -> bool:
77+
"""Parse boolean fields leniently, like the original implementation."""
78+
if v is None:
79+
return False
80+
if isinstance(v, str):
81+
return v.lower() in ("1", "true", "yes", "on")
82+
return bool(v)
83+
84+
85+
def _get_settings() -> MDIOSettings:
86+
"""Get current MDIO settings from environment variables."""
6087
try:
61-
return float(value)
62-
except ValueError as e:
63-
raise EnvironmentFormatError(key, "float") from e
88+
return MDIOSettings()
89+
except ValidationError as e:
90+
# Extract the field name and expected type from the error
91+
error_details = e.errors()[0]
92+
field_name = error_details.get("loc", [None])[0]
93+
error_type = error_details.get("type", "unknown")
94+
95+
# Map pydantic error types to our error types
96+
type_mapping = {
97+
"int_parsing": "int",
98+
"float_parsing": "float",
99+
}
100+
mapped_type = type_mapping.get(error_type, error_type)
101+
102+
# Map field names back to environment variable names for the error
103+
env_var_mapping = {
104+
"export_cpus": "MDIO__EXPORT__CPU_COUNT",
105+
"import_cpus": "MDIO__IMPORT__CPU_COUNT",
106+
"grid_sparsity_ratio_warn": "MDIO__GRID__SPARSITY_RATIO_WARN",
107+
"grid_sparsity_ratio_limit": "MDIO__GRID__SPARSITY_RATIO_LIMIT",
108+
}
109+
env_var = env_var_mapping.get(field_name, field_name)
110+
111+
raise EnvironmentFormatError(env_var, mapped_type) from e
64112

65113

66114
def export_cpus() -> int:
67115
"""Number of CPUs to use for export operations."""
68-
value = _get_env_value(_EXPORT_CPUS_KEY, _EXPORT_CPUS_DEFAULT)
69-
return _parse_int(value, _EXPORT_CPUS_KEY)
116+
return _get_settings().export_cpus
70117

71118

72119
def import_cpus() -> int:
73120
"""Number of CPUs to use for import operations."""
74-
value = _get_env_value(_IMPORT_CPUS_KEY, _IMPORT_CPUS_DEFAULT)
75-
return _parse_int(value, _IMPORT_CPUS_KEY)
121+
return _get_settings().import_cpus
76122

77123

78124
def grid_sparsity_ratio_warn() -> float:
79125
"""Sparsity ratio threshold for warnings."""
80-
value = _get_env_value(_GRID_SPARSITY_RATIO_WARN_KEY, _GRID_SPARSITY_RATIO_WARN_DEFAULT)
81-
return _parse_float(value, _GRID_SPARSITY_RATIO_WARN_KEY)
126+
return _get_settings().grid_sparsity_ratio_warn
82127

83128

84129
def grid_sparsity_ratio_limit() -> float:
85130
"""Sparsity ratio threshold for errors."""
86-
value = _get_env_value(_GRID_SPARSITY_RATIO_LIMIT_KEY, _GRID_SPARSITY_RATIO_LIMIT_DEFAULT)
87-
return _parse_float(value, _GRID_SPARSITY_RATIO_LIMIT_KEY)
131+
return _get_settings().grid_sparsity_ratio_limit
88132

89133

90134
def save_segy_file_header() -> bool:
91135
"""Whether to save SEG-Y file headers."""
92-
value = _get_env_value(_SAVE_SEGY_FILE_HEADER_KEY, _SAVE_SEGY_FILE_HEADER_DEFAULT)
93-
return _parse_bool(value)
136+
return _get_settings().save_segy_file_header
94137

95138

96139
def mdio_segy_spec() -> str | None:
97140
"""Path to MDIO SEG-Y specification file."""
98-
return _get_env_value(_MDIO_SEGY_SPEC_KEY, _MDIO_SEGY_SPEC_DEFAULT)
141+
return _get_settings().mdio_segy_spec
99142

100143

101144
def raw_headers() -> bool:
102145
"""Whether to preserve raw headers."""
103-
value = _get_env_value(_RAW_HEADERS_KEY, _RAW_HEADERS_DEFAULT)
104-
return _parse_bool(value)
146+
return _get_settings().raw_headers
105147

106148

107149
def ignore_checks() -> bool:
108150
"""Whether to ignore validation checks."""
109-
value = _get_env_value(_IGNORE_CHECKS_KEY, _IGNORE_CHECKS_DEFAULT)
110-
return _parse_bool(value)
151+
return _get_settings().ignore_checks
111152

112153

113154
def cloud_native() -> bool:
114155
"""Whether to use cloud-native mode for SEG-Y processing."""
115-
value = _get_env_value(_CLOUD_NATIVE_KEY, _CLOUD_NATIVE_DEFAULT)
116-
return _parse_bool(value)
156+
return _get_settings().cloud_native

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)