@@ -42,34 +42,31 @@ def load_config(config_path: Union[str, Path]) -> Config:
4242def save_config (cfg : Config , save_path : Union [str , Path ]) -> None :
4343 """
4444 Save configuration to YAML file.
45-
45+
4646 Args:
4747 cfg: Config object to save
4848 save_path: Path where to save the YAML file
4949 """
5050 save_path = Path (save_path )
5151 save_path .parent .mkdir (parents = True , exist_ok = True )
52-
52+
5353 omega_conf = OmegaConf .structured (cfg )
5454 OmegaConf .save (omega_conf , save_path )
5555
5656
57- def merge_configs (
58- base_cfg : Config ,
59- * override_cfgs : Union [Config , Dict , str , Path ]
60- ) -> Config :
57+ def merge_configs (base_cfg : Config , * override_cfgs : Union [Config , Dict , str , Path ]) -> Config :
6158 """
6259 Merge multiple configurations together.
63-
60+
6461 Args:
6562 base_cfg: Base configuration
6663 *override_cfgs: One or more override configs (Config, dict, or path to YAML)
67-
64+
6865 Returns:
6966 Merged Config object
7067 """
7168 result = OmegaConf .structured (base_cfg )
72-
69+
7370 for override_cfg in override_cfgs :
7471 if isinstance (override_cfg , (str , Path )):
7572 override_omega = OmegaConf .load (override_cfg )
@@ -79,22 +76,22 @@ def merge_configs(
7976 override_omega = OmegaConf .create (override_cfg )
8077 else :
8178 raise TypeError (f"Unsupported config type: { type (override_cfg )} " )
82-
79+
8380 result = OmegaConf .merge (result , override_omega )
84-
81+
8582 return OmegaConf .to_object (result )
8683
8784
8885def update_from_cli (cfg : Config , overrides : List [str ]) -> Config :
8986 """
9087 Update config from command-line overrides.
91-
88+
9289 Supports dot notation: ['data.batch_size=4', 'model.architecture=unet3d']
93-
90+
9491 Args:
9592 cfg: Base Config object
9693 overrides: List of 'key=value' strings
97-
94+
9895 Returns:
9996 Updated Config object
10097 """
@@ -107,11 +104,11 @@ def update_from_cli(cfg: Config, overrides: List[str]) -> Config:
107104def to_dict (cfg : Config , resolve : bool = True ) -> Dict [str , Any ]:
108105 """
109106 Convert Config to dictionary.
110-
107+
111108 Args:
112109 cfg: Config object
113110 resolve: Whether to resolve variable interpolations
114-
111+
115112 Returns:
116113 Dictionary representation
117114 """
@@ -122,10 +119,10 @@ def to_dict(cfg: Config, resolve: bool = True) -> Dict[str, Any]:
122119def from_dict (d : Dict [str , Any ]) -> Config :
123120 """
124121 Create Config from dictionary.
125-
122+
126123 Args:
127124 d: Dictionary with configuration values
128-
125+
129126 Returns:
130127 Config object
131128 """
@@ -138,7 +135,7 @@ def from_dict(d: Dict[str, Any]) -> Config:
138135def print_config (cfg : Config , resolve : bool = True ) -> None :
139136 """
140137 Pretty print configuration.
141-
138+
142139 Args:
143140 cfg: Config to print
144141 resolve: Whether to resolve variable interpolations
@@ -150,10 +147,10 @@ def print_config(cfg: Config, resolve: bool = True) -> None:
150147def validate_config (cfg : Config ) -> None :
151148 """
152149 Validate configuration values.
153-
150+
154151 Args:
155152 cfg: Config object to validate
156-
153+
157154 Raises:
158155 ValueError: If configuration is invalid
159156 """
@@ -162,18 +159,18 @@ def validate_config(cfg: Config) -> None:
162159 raise ValueError ("model.in_channels must be positive" )
163160 if cfg .model .out_channels <= 0 :
164161 raise ValueError ("model.out_channels must be positive" )
165- if len (cfg .model .input_size ) != 3 :
166- raise ValueError ("model.input_size must be 3D" )
167-
162+ if len (cfg .model .input_size ) not in [ 2 , 3 ] :
163+ raise ValueError ("model.input_size must be 2D or 3D (got length {})" . format ( len ( cfg . model . input_size )) )
164+
168165 # System validation
169166 if cfg .system .training .batch_size <= 0 :
170167 raise ValueError ("system.training.batch_size must be positive" )
171168 if cfg .system .training .num_workers < 0 :
172169 raise ValueError ("system.training.num_workers must be non-negative" )
173170
174171 # Data validation
175- if len (cfg .data .patch_size ) != 3 :
176- raise ValueError ("data.patch_size must be 3D" )
172+ if len (cfg .data .patch_size ) not in [ 2 , 3 ] :
173+ raise ValueError ("data.patch_size must be 2D or 3D (got length {})" . format ( len ( cfg . data . patch_size )) )
177174
178175 # Optimizer validation
179176 if cfg .optimization .optimizer .lr <= 0 :
@@ -188,7 +185,7 @@ def validate_config(cfg: Config) -> None:
188185 raise ValueError ("optimization.gradient_clip_val must be non-negative" )
189186 if cfg .optimization .accumulate_grad_batches <= 0 :
190187 raise ValueError ("optimization.accumulate_grad_batches must be positive" )
191-
188+
192189 # Loss validation
193190 if len (cfg .model .loss_functions ) != len (cfg .model .loss_weights ):
194191 raise ValueError ("loss_functions and loss_weights must have same length" )
@@ -199,16 +196,17 @@ def validate_config(cfg: Config) -> None:
199196def get_config_hash (cfg : Config ) -> str :
200197 """
201198 Generate a hash string for the configuration.
202-
199+
203200 Useful for experiment tracking and reproducibility.
204-
201+
205202 Args:
206203 cfg: Config object
207-
204+
208205 Returns:
209206 Hash string
210207 """
211208 import hashlib
209+
212210 omega_conf = OmegaConf .structured (cfg )
213211 yaml_str = OmegaConf .to_yaml (omega_conf , resolve = True )
214212 return hashlib .md5 (yaml_str .encode ()).hexdigest ()[:8 ]
@@ -228,20 +226,20 @@ def create_experiment_name(cfg: Config) -> str:
228226 cfg .model .architecture ,
229227 f"bs{ cfg .system .training .batch_size } " ,
230228 f"lr{ cfg .optimization .optimizer .lr :.0e} " ,
231- get_config_hash (cfg )
229+ get_config_hash (cfg ),
232230 ]
233231 return "_" .join (parts )
234232
235233
236234__all__ = [
237- ' load_config' ,
238- ' save_config' ,
239- ' merge_configs' ,
240- ' update_from_cli' ,
241- ' to_dict' ,
242- ' from_dict' ,
243- ' print_config' ,
244- ' validate_config' ,
245- ' get_config_hash' ,
246- ' create_experiment_name' ,
247- ]
235+ " load_config" ,
236+ " save_config" ,
237+ " merge_configs" ,
238+ " update_from_cli" ,
239+ " to_dict" ,
240+ " from_dict" ,
241+ " print_config" ,
242+ " validate_config" ,
243+ " get_config_hash" ,
244+ " create_experiment_name" ,
245+ ]
0 commit comments