Skip to content

Commit a2d0a19

Browse files
Merge pull request #2775 from SamuelMarks:loud-fields
PiperOrigin-RevId: 842371995
2 parents 5c77d54 + 43ff2e0 commit a2d0a19

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ class ProfilerType(str, Enum):
223223
class RunInfo(BaseModel):
224224
"""Configuration for the overall run, model identity, and logging."""
225225

226+
base_config: None | str = Field(
227+
None,
228+
description="Base config to inherit from. This is a meta-field and is consumed by the config loading system.",
229+
)
226230
run_name: str = Field(
227231
"",
228232
description="The name of the run. Checkpoints will be stored under this name.",

src/MaxText/pyconfig.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from MaxText import pyconfig_deprecated
3030
from MaxText.common_types import DecoderBlockType, ShardMode
3131
from MaxText.configs import types
32+
from MaxText.configs.types import MaxTextConfig
3233
from MaxText.inference_utils import str2bool
3334

3435
logger = logging.getLogger(__name__)
@@ -104,7 +105,7 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
104105
for key, value in raw_keys.items():
105106
if key not in valid_fields:
106107
logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key))
107-
continue
108+
raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.")
108109

109110
new_value = value
110111
if isinstance(new_value, str) and new_value.lower() == "none":
@@ -179,6 +180,21 @@ def get_keys(self) -> dict[str, Any]:
179180

180181
def initialize(argv: list[str], **kwargs) -> HyperParameters:
181182
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
183+
pydantic_config = initialize_pydantic(argv, **kwargs)
184+
config = HyperParameters(pydantic_config)
185+
186+
if config.log_config:
187+
for k, v in sorted(config.get_keys().items()):
188+
if k != "hf_access_token":
189+
logger.info("Config param %s: %s", k, v)
190+
191+
return config
192+
193+
194+
def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
195+
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
196+
Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
197+
"""
182198
# 1. Load base and inherited configs from file(s)
183199
config_path = resolve_config_path(argv[1])
184200
base_yml_config = _load_config(config_path)
@@ -287,9 +303,9 @@ def initialize(argv: list[str], **kwargs) -> HyperParameters:
287303
if k not in KEYS_NO_LOGGING:
288304
logger.info("Config param %s: %s", k, v)
289305

290-
return config
306+
return pydantic_config
291307

292308

293309
# Shim for backward compatibility with pyconfig_deprecated_test.py
294310
validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
295-
__all__ = ["initialize"]
311+
__all__ = ["initialize", "initialize_pydantic"]

tests/configs_value_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from MaxText import pyconfig
2424
from MaxText.configs import types
2525
from MaxText.globals import MAXTEXT_REPO_ROOT
26+
from MaxText.pyconfig import initialize_pydantic
2627

2728
# Path to the base.yml config. This assumes that `pytest` is run from the project root.
2829
_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml")
@@ -131,6 +132,18 @@ def test_llama3_tokenizer_correction(self):
131132
config = pyconfig.initialize(argv)
132133
self.assertEqual(config.tokenizer_type, "tiktoken")
133134

135+
def test_initialize_pydantic_bad_keys(self):
136+
"""Test that `pydantic.ValidationError` is raised on keys not in MaxTextConfig"""
137+
with self.assertRaises(ValueError):
138+
initialize_pydantic(
139+
[
140+
"",
141+
_BASE_CONFIG_PATH,
142+
"tokenizer_path=assets/tokenizer_llama3.tiktoken",
143+
"NOT_A_VALID_KEY=test",
144+
]
145+
)
146+
134147

135148
if __name__ == "__main__":
136149
unittest.main()

0 commit comments

Comments
 (0)