Skip to content

Commit 417f766

Browse files
committed
Added 'field_validator' functions for the 'calibrations' and 'software_versions' keys in the 'MachineConfig' model; added helper model to correctly validate the magnification calibration table
1 parent 61e2192 commit 417f766

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

src/murfey/util/config.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,23 @@
44
import socket
55
from functools import lru_cache
66
from pathlib import Path
7-
from typing import Literal, Optional, Union
7+
from typing import Any, Literal, Optional
88

99
import yaml
1010
from backports.entry_points_selectable import entry_points
11-
from pydantic import BaseModel, ConfigDict, field_validator
11+
from pydantic import BaseModel, ConfigDict, RootModel, ValidationInfo, field_validator
1212
from pydantic_settings import BaseSettings
1313

1414

15+
class MagnificationTable(RootModel[dict[int, float]]):
16+
pass
17+
18+
19+
CALIBRATIONS_VALIDATION_SCHEMAS = {
20+
"magnification": MagnificationTable,
21+
}
22+
23+
1524
class MachineConfig(BaseModel): # type: ignore
1625
"""
1726
Keys that describe the type of workflow conducted on the client side, and how
@@ -27,7 +36,7 @@ class MachineConfig(BaseModel): # type: ignore
2736
# Hardware and software -----------------------------------------------------------
2837
camera: str = "FALCON"
2938
superres: bool = False
30-
calibrations: dict[str, dict[str, Union[dict, float]]]
39+
calibrations: dict[str, Any]
3140
acquisition_software: list[str]
3241
software_versions: dict[str, str] = {}
3342
software_settings_output_directories: dict[str, list[str]] = {}
@@ -94,8 +103,44 @@ class MachineConfig(BaseModel): # type: ignore
94103
node_creator_queue: str = "node_creator"
95104
notifications_queue: str = "pato_notification"
96105

106+
# Pydantic BaseModel settings
97107
model_config = ConfigDict(extra="allow")
98108

109+
@field_validator("calibrations", mode="before")
110+
@classmethod
111+
def validate_calibration_data(
112+
cls, v: dict[str, dict[Any, Any]]
113+
) -> dict[str, dict[Any, Any]]:
114+
# Pass the calibration dictionaries through their matching Pydantic models, if any are set
115+
if isinstance(v, dict):
116+
validated = {}
117+
for (
118+
key,
119+
value,
120+
) in v.items():
121+
model_cls = CALIBRATIONS_VALIDATION_SCHEMAS.get(key)
122+
if model_cls:
123+
try:
124+
# Validate and store as a dict object with the corrected types
125+
validated[key] = model_cls.model_validate(value).root
126+
except Exception as e:
127+
raise ValueError(f"Validation failed for key '{key}': {e}")
128+
else:
129+
validated[key] = value
130+
return validated
131+
# Let it validate and fail as-is
132+
return v
133+
134+
@field_validator("software_versions", mode="before")
135+
@classmethod
136+
def validate_software_versions(cls, v: dict[str, Any]) -> dict[str, str]:
137+
# Software versions should be numerical strings, even if they appear int- or float-like
138+
if isinstance(v, dict):
139+
validated = {key: str(value) for key, value in v.items()}
140+
return validated
141+
# Let it validate and fail as-is
142+
return v
143+
99144

100145
def from_file(config_file_path: Path, instrument: str = "") -> dict[str, MachineConfig]:
101146
with open(config_file_path, "r") as config_stream:
@@ -141,9 +186,9 @@ class Security(BaseModel):
141186

142187
@field_validator("graylog_port")
143188
def check_port_present_if_host_is(
144-
cls, v: Optional[int], values: dict, **kwargs
189+
cls, v: Optional[int], info: ValidationInfo, **kwargs
145190
) -> Optional[int]:
146-
if values["graylog_host"] and v is None:
191+
if info.data.get("graylog_host") and v is None:
147192
raise ValueError("The Graylog port must be set if the Graylog host is")
148193
return v
149194

0 commit comments

Comments
 (0)