@@ -47,6 +47,22 @@ def resolve_config_path(param: str) -> str:
4747 return param if os .path .isfile (param ) else os .path .join ("src" , param )
4848
4949
50+ def _merge_logical_axis_rules (base_rules , new_rules ):
51+ """Merges two lists of logical_axis_rules. Rules in new_rules override all rules
52+ with the same name in base_rules."""
53+ if not new_rules :
54+ return base_rules
55+
56+ new_rule_keys = {rule [0 ] for rule in new_rules }
57+
58+ # Filter old rules to exclude any that will be replaced.
59+ updated_rules = [rule for rule in base_rules if rule [0 ] not in new_rule_keys ]
60+
61+ # Add all the new rules.
62+ updated_rules .extend (new_rules )
63+ return updated_rules
64+
65+
5066def _load_config (config_name : str ) -> omegaconf .DictConfig :
5167 """Loads a YAML file and its base_configs recursively using OmegaConf."""
5268 cfg = omegaconf .OmegaConf .load (config_name )
@@ -185,13 +201,38 @@ def initialize(argv: list[str], **kwargs) -> HyperParameters:
185201 logger .warning ("Model config for '%s' not found at %s" , model_name , model_config_path )
186202
187203 # 4. Final merge (base, model, then overrides)
188- final_config = omegaconf .OmegaConf .merge (base_yml_config , model_cfg , overrides_cfg )
204+ model_cfg_oc = omegaconf .OmegaConf .create (model_cfg )
205+
206+ # 4. Manually merge logical_axis_rules to avoid OmegaConf's list replacement behavior.
207+ base_rules_oc = base_yml_config .get ("logical_axis_rules" , [])
208+ model_rules_oc = model_cfg_oc .get ("logical_axis_rules" , [])
209+ overrides_rules_oc = overrides_cfg .get ("logical_axis_rules" , [])
210+
211+ base_rules = omegaconf .OmegaConf .to_container (base_rules_oc , resolve = True ) if base_rules_oc else []
212+ model_rules = omegaconf .OmegaConf .to_container (model_rules_oc , resolve = True ) if model_rules_oc else []
213+ overrides_rules = omegaconf .OmegaConf .to_container (overrides_rules_oc , resolve = True ) if overrides_rules_oc else []
214+
215+ merged_rules = _merge_logical_axis_rules (base_rules , model_rules )
216+ merged_rules = _merge_logical_axis_rules (merged_rules , overrides_rules )
217+
218+ # Remove the rules from the original configs before the main merge
219+ if "logical_axis_rules" in base_yml_config :
220+ del base_yml_config ["logical_axis_rules" ]
221+ if "logical_axis_rules" in model_cfg_oc :
222+ del model_cfg_oc ["logical_axis_rules" ]
223+ if "logical_axis_rules" in overrides_cfg :
224+ del overrides_cfg ["logical_axis_rules" ]
225+
226+ # 5. Final merge for all other keys
227+ final_config = omegaconf .OmegaConf .merge (base_yml_config , model_cfg_oc , overrides_cfg )
228+ final_config ["logical_axis_rules" ] = merged_rules
229+
189230 raw_keys_dict = omegaconf .OmegaConf .to_container (final_config , resolve = True )
190231
191- # 5 . Handle environment variable overrides
192- cli_keys = set (omegaconf .OmegaConf .to_container (cli_cfg , resolve = True ).keys ())
193- kwargs_keys = set (kwargs .keys ())
194- for k in list (raw_keys_dict .keys ()):
232+ # 6 . Handle environment variable overrides
233+ cli_keys = frozenset (omegaconf .OmegaConf .to_container (cli_cfg , resolve = True ).keys ())
234+ kwargs_keys = frozenset (kwargs .keys ())
235+ for k in tuple (raw_keys_dict .keys ()):
195236 env_key = yaml_key_to_env_key (k )
196237 if env_key in os .environ :
197238 if k in cli_keys or k in kwargs_keys :
0 commit comments