def validate_and_set_hlo_dump_defaults(raw_keys: dict[str, Any]) -> dict[str, Any]:
  """Validates the HLO dump settings and fills in their defaults, in place.

  Nothing happens unless `raw_keys["dump_hlo"]` is truthy. When it is:

  * XLA flags may come from the `XLA_FLAGS` environment variable or from
    `raw_keys["dump_hlo_xla_flags"]`, but not from both.
  * If neither is set, `dump_hlo_xla_flags` is built from `dump_hlo_local_dir`,
    plus a module-name filter when `dump_hlo_local_module_name` is set.
  * An empty `dump_hlo_gcs_dir` defaults to
    `<base_output_directory>/<run_name>/xla_dump`; a provided one is passed
    through `gcs_utils.add_trailing_slash`.
  * If `XLA_FLAGS` was not set in the environment, it is set to the resulting
    `dump_hlo_xla_flags`.

  Args:
    raw_keys: Configuration mapping. It is mutated in place.

  Returns:
    The same `raw_keys` object that was passed in.

  Raises:
    ValueError: If both the `XLA_FLAGS` environment variable and
      `dump_hlo_xla_flags` are set.
    KeyError: If a configuration key read on the taken path is missing.
  """
  if not raw_keys["dump_hlo"]:
    return raw_keys

  # Read the environment once; this function only writes XLA_FLAGS at the very end.
  env_xla_flags = os.environ.get("XLA_FLAGS")

  if env_xla_flags and raw_keys["dump_hlo_xla_flags"]:
    raise ValueError("You must set either XLA_FLAGS or dump_hlo_xla_flags to dump HLO, but not both.")

  if not env_xla_flags and not raw_keys["dump_hlo_xla_flags"]:
    # No flags were supplied anywhere, so build the default dump flags.
    default_flags = f"--xla_dump_to={raw_keys['dump_hlo_local_dir']} --xla_dump_large_constants"
    if raw_keys["dump_hlo_local_module_name"]:
      default_flags = f"{default_flags} --xla_dump_hlo_module_re={raw_keys['dump_hlo_local_module_name']}"
    raw_keys["dump_hlo_xla_flags"] = default_flags

  if raw_keys["dump_hlo_gcs_dir"]:
    raw_keys["dump_hlo_gcs_dir"] = gcs_utils.add_trailing_slash(raw_keys["dump_hlo_gcs_dir"])
  else:
    raw_keys["dump_hlo_gcs_dir"] = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], "xla_dump")

  # A pre-existing XLA_FLAGS value is left untouched; otherwise export the config flags.
  if not env_xla_flags:
    os.environ["XLA_FLAGS"] = raw_keys["dump_hlo_xla_flags"]
  return raw_keys