|
29 | 29 | from MaxText import pyconfig_deprecated |
30 | 30 | from MaxText.common_types import DecoderBlockType, ShardMode |
31 | 31 | from MaxText.configs import types |
| 32 | +from MaxText.configs.types import MaxTextConfig |
32 | 33 | from MaxText.inference_utils import str2bool |
33 | 34 |
|
34 | 35 | logger = logging.getLogger(__name__) |
@@ -104,7 +105,7 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]: |
104 | 105 | for key, value in raw_keys.items(): |
105 | 106 | if key not in valid_fields: |
106 | 107 | logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key)) |
107 | | - continue |
| 108 | + raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.") |
108 | 109 |
|
109 | 110 | new_value = value |
110 | 111 | if isinstance(new_value, str) and new_value.lower() == "none": |
@@ -179,6 +180,21 @@ def get_keys(self) -> dict[str, Any]: |
179 | 180 |
|
180 | 181 | def initialize(argv: list[str], **kwargs) -> HyperParameters: |
181 | 182 | """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.""" |
| 183 | + pydantic_config = initialize_pydantic(argv, **kwargs) |
| 184 | + config = HyperParameters(pydantic_config) |
| 185 | + |
| 186 | + if config.log_config: |
| 187 | + for k, v in sorted(config.get_keys().items()): |
| 188 | + if k != "hf_access_token": |
| 189 | + logger.info("Config param %s: %s", k, v) |
| 190 | + |
| 191 | + return config |
| 192 | + |
| 193 | + |
| 194 | +def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig: |
| 195 | + """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides. |
| 196 | + Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters` |
| 197 | + """ |
182 | 198 | # 1. Load base and inherited configs from file(s) |
183 | 199 | config_path = resolve_config_path(argv[1]) |
184 | 200 | base_yml_config = _load_config(config_path) |
@@ -287,9 +303,9 @@ def initialize(argv: list[str], **kwargs) -> HyperParameters: |
287 | 303 | if k not in KEYS_NO_LOGGING: |
288 | 304 | logger.info("Config param %s: %s", k, v) |
289 | 305 |
|
290 | | - return config |
| 306 | + return pydantic_config |
291 | 307 |
|
292 | 308 |
|
293 | 309 | # Shim for backward compatibility with pyconfig_deprecated_test.py |
294 | 310 | validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys |
295 | | -__all__ = ["initialize"] |
| 311 | +__all__ = ["initialize", "initialize_pydantic"] |
0 commit comments