44import socket
55from functools import lru_cache
66from pathlib import Path
7- from typing import Literal , Optional , Union
7+ from typing import Any , Literal , Optional
88
99import yaml
1010from backports .entry_points_selectable import entry_points
11- from pydantic import BaseModel , ConfigDict , field_validator
11+ from pydantic import BaseModel , ConfigDict , RootModel , ValidationInfo , field_validator
1212from pydantic_settings import BaseSettings
1313
1414
15+ class MagnificationTable (RootModel [dict [int , float ]]):
16+ pass
17+
18+
19+ CALIBRATIONS_VALIDATION_SCHEMAS = {
20+ "magnification" : MagnificationTable ,
21+ }
22+
23+
1524class MachineConfig (BaseModel ): # type: ignore
1625 """
1726 Keys that describe the type of workflow conducted on the client side, and how
@@ -27,7 +36,7 @@ class MachineConfig(BaseModel): # type: ignore
2736 # Hardware and software -----------------------------------------------------------
2837 camera : str = "FALCON"
2938 superres : bool = False
30- calibrations : dict [str , dict [ str , Union [ dict , float ]] ]
39+ calibrations : dict [str , Any ]
3140 acquisition_software : list [str ]
3241 software_versions : dict [str , str ] = {}
3342 software_settings_output_directories : dict [str , list [str ]] = {}
@@ -94,8 +103,44 @@ class MachineConfig(BaseModel): # type: ignore
94103 node_creator_queue : str = "node_creator"
95104 notifications_queue : str = "pato_notification"
96105
106+ # Pydantic BaseModel settings
97107 model_config = ConfigDict (extra = "allow" )
98108
109+ @field_validator ("calibrations" , mode = "before" )
110+ @classmethod
111+ def validate_calibration_data (
112+ cls , v : dict [str , dict [Any , Any ]]
113+ ) -> dict [str , dict [Any , Any ]]:
114+ # Pass the calibration dictionaries through their matching Pydantic models, if any are set
115+ if isinstance (v , dict ):
116+ validated = {}
117+ for (
118+ key ,
119+ value ,
120+ ) in v .items ():
121+ model_cls = CALIBRATIONS_VALIDATION_SCHEMAS .get (key )
122+ if model_cls :
123+ try :
124+ # Validate and store as a dict object with the corrected types
125+ validated [key ] = model_cls .model_validate (value ).root
126+ except Exception as e :
127+ raise ValueError (f"Validation failed for key '{ key } ': { e } " )
128+ else :
129+ validated [key ] = value
130+ return validated
131+ # Let it validate and fail as-is
132+ return v
133+
134+ @field_validator ("software_versions" , mode = "before" )
135+ @classmethod
136+ def validate_software_versions (cls , v : dict [str , Any ]) -> dict [str , str ]:
137+ # Software versions should be numerical strings, even if they appear int- or float-like
138+ if isinstance (v , dict ):
139+ validated = {key : str (value ) for key , value in v .items ()}
140+ return validated
141+ # Let it validate and fail as-is
142+ return v
143+
99144
100145def from_file (config_file_path : Path , instrument : str = "" ) -> dict [str , MachineConfig ]:
101146 with open (config_file_path , "r" ) as config_stream :
@@ -141,9 +186,9 @@ class Security(BaseModel):
141186
142187 @field_validator ("graylog_port" )
143188 def check_port_present_if_host_is (
144- cls , v : Optional [int ], values : dict , ** kwargs
189+ cls , v : Optional [int ], info : ValidationInfo , ** kwargs
145190 ) -> Optional [int ]:
146- if values [ "graylog_host" ] and v is None :
191+ if info . data . get ( "graylog_host" ) and v is None :
147192 raise ValueError ("The Graylog port must be set if the Graylog host is" )
148193 return v
149194
0 commit comments