|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +These tests verify the Compile-Then-Load workflow. |
| 17 | +It ensures that a model compiled via train_compile.py can be successfully |
| 18 | +loaded and executed by train.py. |
| 19 | +""" |
| 20 | + |
| 21 | +import tempfile |
| 22 | +import unittest |
| 23 | +import pytest |
| 24 | +import os |
| 25 | +import shutil |
| 26 | +import jax |
| 27 | +from MaxText.globals import MAXTEXT_PKG_DIR |
| 28 | +from MaxText import train_compile |
| 29 | +from MaxText import train |
| 30 | + |
| 31 | + |
| 32 | +class CompileThenLoadTest(unittest.TestCase): |
| 33 | + """Tests for the Split Compile and Train workflow""" |
| 34 | + |
| 35 | + def setUp(self): |
| 36 | + """Create a temporary directory for the compiled pickle file.""" |
| 37 | + self.temp_dir = tempfile.mkdtemp() |
| 38 | + self.pickle_file = os.path.join(self.temp_dir, "test_compiled_train.pickle") |
| 39 | + |
| 40 | + # Ensure JAX cache doesn't interfere with clean test runs |
| 41 | + jax.config.update("jax_enable_compilation_cache", False) |
| 42 | + |
| 43 | + def tearDown(self): |
| 44 | + """Clean up the temporary directory.""" |
| 45 | + if os.path.exists(self.temp_dir): |
| 46 | + shutil.rmtree(self.temp_dir) |
| 47 | + |
| 48 | + def get_device_user_facing_name(self): |
| 49 | + """Gets TPU device user facing name to generate correct AOT arguments.""" |
| 50 | + devices = jax.devices() |
| 51 | + if not devices or "tpu" not in devices[0].platform.lower(): |
| 52 | + pytest.skip("This test requires a TPU environment.") |
| 53 | + |
| 54 | + num_devices = len(devices) |
| 55 | + device_kind = devices[0].device_kind |
| 56 | + device_info = { |
| 57 | + "TPU v4": ("v4", 2 * num_devices), |
| 58 | + "TPU v5 lite": ("v5e", num_devices), |
| 59 | + "TPU v5": ("v5p", 2 * num_devices), |
| 60 | + "TPU v6": ("v6e", num_devices), |
| 61 | + } |
| 62 | + |
| 63 | + prefix, topology_devices = next((v for k, v in device_info.items() if k in device_kind), (None, None)) |
| 64 | + if prefix is None: |
| 65 | + raise ValueError(f"Unsupported TPU device kind for AOT test: {device_kind}") |
| 66 | + |
| 67 | + return f"{prefix}-{topology_devices}" |
| 68 | + |
| 69 | + def run_compile_then_load(self, test_name, *extra_args): |
| 70 | + """ |
| 71 | + Executes the compile step, checks for pickle existence, |
| 72 | + then executes the load/train step. |
| 73 | + """ |
| 74 | + |
| 75 | + # Common arguments derived from your request |
| 76 | + shared_args = [ |
| 77 | + "global_parameter_scale=1", |
| 78 | + "per_device_batch_size=4", |
| 79 | + "steps=1", |
| 80 | + "learning_rate=1e-3", |
| 81 | + "dataset_type=synthetic", |
| 82 | + "enable_checkpointing=False", |
| 83 | + ] |
| 84 | + |
| 85 | + if extra_args: |
| 86 | + shared_args.extend(extra_args) |
| 87 | + |
| 88 | + # Compilation |
| 89 | + topology = self.get_device_user_facing_name() |
| 90 | + |
| 91 | + compile_specific_args = [ |
| 92 | + f"compile_topology={topology}", |
| 93 | + "compile_topology_num_slices=1", |
| 94 | + f"compiled_trainstep_file={self.pickle_file}", |
| 95 | + ] |
| 96 | + |
| 97 | + compile_argv = ( |
| 98 | + (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(compile_specific_args) |
| 99 | + ) |
| 100 | + |
| 101 | + print(f"\n--- Starting Compilation Step for {test_name} ---") |
| 102 | + # Clear caches before compile to ensure clean state |
| 103 | + jax.clear_caches() |
| 104 | + train_compile.main(compile_argv) |
| 105 | + |
| 106 | + # Assert the pickle file was actually created |
| 107 | + assert os.path.exists(self.pickle_file), f"Compilation failed: {self.pickle_file} was not created." |
| 108 | + |
| 109 | + load_specific_args = [ |
| 110 | + "base_output_directory=gs://runner-maxtext-logs", |
| 111 | + f"run_name=compile_then_load_{test_name}", |
| 112 | + f"compiled_trainstep_file={self.pickle_file}", |
| 113 | + ] |
| 114 | + |
| 115 | + train_argv = ( |
| 116 | + (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(load_specific_args) |
| 117 | + ) |
| 118 | + |
| 119 | + print(f"\n--- Starting Load/Train Step for {test_name} ---") |
| 120 | + # Clear caches before train to ensure we are actually loading from the pickle |
| 121 | + jax.clear_caches() |
| 122 | + train.main(train_argv) |
| 123 | + |
| 124 | + print(f"Successfully compiled and loaded for test {test_name}!") |
| 125 | + |
| 126 | + @pytest.mark.tpu_only |
| 127 | + def test_default_compile_load(self): |
| 128 | + self.run_compile_then_load("default_run") |
0 commit comments