Skip to content

Commit c0abc4c

Browse files
Merge pull request #2704 from SamuelMarks:axis-conf
PiperOrigin-RevId: 833440797
2 parents 0eb74b4 + b0cd857 commit c0abc4c

File tree

1 file changed

+46
-5
lines changed

1 file changed

+46
-5
lines changed

src/MaxText/pyconfig.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,22 @@ def resolve_config_path(param: str) -> str:
4747
return param if os.path.isfile(param) else os.path.join("src", param)
4848

4949

50+
def _merge_logical_axis_rules(base_rules, new_rules):
51+
"""Merges two lists of logical_axis_rules. Rules in new_rules override all rules
52+
with the same name in base_rules."""
53+
if not new_rules:
54+
return base_rules
55+
56+
new_rule_keys = {rule[0] for rule in new_rules}
57+
58+
# Filter old rules to exclude any that will be replaced.
59+
updated_rules = [rule for rule in base_rules if rule[0] not in new_rule_keys]
60+
61+
# Add all the new rules.
62+
updated_rules.extend(new_rules)
63+
return updated_rules
64+
65+
5066
def _load_config(config_name: str) -> omegaconf.DictConfig:
5167
"""Loads a YAML file and its base_configs recursively using OmegaConf."""
5268
cfg = omegaconf.OmegaConf.load(config_name)
@@ -185,13 +201,38 @@ def initialize(argv: list[str], **kwargs) -> HyperParameters:
185201
logger.warning("Model config for '%s' not found at %s", model_name, model_config_path)
186202

187203
# 4. Final merge (base, model, then overrides)
188-
final_config = omegaconf.OmegaConf.merge(base_yml_config, model_cfg, overrides_cfg)
204+
model_cfg_oc = omegaconf.OmegaConf.create(model_cfg)
205+
206+
# 4. Manually merge logical_axis_rules to avoid OmegaConf's list replacement behavior.
207+
base_rules_oc = base_yml_config.get("logical_axis_rules", [])
208+
model_rules_oc = model_cfg_oc.get("logical_axis_rules", [])
209+
overrides_rules_oc = overrides_cfg.get("logical_axis_rules", [])
210+
211+
base_rules = omegaconf.OmegaConf.to_container(base_rules_oc, resolve=True) if base_rules_oc else []
212+
model_rules = omegaconf.OmegaConf.to_container(model_rules_oc, resolve=True) if model_rules_oc else []
213+
overrides_rules = omegaconf.OmegaConf.to_container(overrides_rules_oc, resolve=True) if overrides_rules_oc else []
214+
215+
merged_rules = _merge_logical_axis_rules(base_rules, model_rules)
216+
merged_rules = _merge_logical_axis_rules(merged_rules, overrides_rules)
217+
218+
# Remove the rules from the original configs before the main merge
219+
if "logical_axis_rules" in base_yml_config:
220+
del base_yml_config["logical_axis_rules"]
221+
if "logical_axis_rules" in model_cfg_oc:
222+
del model_cfg_oc["logical_axis_rules"]
223+
if "logical_axis_rules" in overrides_cfg:
224+
del overrides_cfg["logical_axis_rules"]
225+
226+
# 5. Final merge for all other keys
227+
final_config = omegaconf.OmegaConf.merge(base_yml_config, model_cfg_oc, overrides_cfg)
228+
final_config["logical_axis_rules"] = merged_rules
229+
189230
raw_keys_dict = omegaconf.OmegaConf.to_container(final_config, resolve=True)
190231

191-
# 5. Handle environment variable overrides
192-
cli_keys = set(omegaconf.OmegaConf.to_container(cli_cfg, resolve=True).keys())
193-
kwargs_keys = set(kwargs.keys())
194-
for k in list(raw_keys_dict.keys()):
232+
# 6. Handle environment variable overrides
233+
cli_keys = frozenset(omegaconf.OmegaConf.to_container(cli_cfg, resolve=True).keys())
234+
kwargs_keys = frozenset(kwargs.keys())
235+
for k in tuple(raw_keys_dict.keys()):
195236
env_key = yaml_key_to_env_key(k)
196237
if env_key in os.environ:
197238
if k in cli_keys or k in kwargs_keys:

0 commit comments

Comments
 (0)