Commit b6aaa87

Merge pull request #2705 from AI-Hypercomputer:chengnuojin-xaot

PiperOrigin-RevId: 834357725
2 parents 8de5059 + 0f84a7a commit b6aaa87

File tree: 5 files changed, +164 -14 lines

.github/workflows/run_pathways_tests_internal.yml

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ jobs:
           python3 -m pip install -e . --no-dependencies &&
           python3 -m pip uninstall -y libtpu &&
           # TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
-          python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest" --durations=0
+          python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0

     services:
       resource_manager:
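pytest's -k flag takes a boolean expression matched against collected test IDs, so chaining "and not" clauses deselects several suites in one filter. A minimal illustration (the surrounding test IDs are hypothetical; only the two excluded class names come from this commit):

  # Suppose collection yields:
  #   tests/xaot_test.py::CompileThenLoadTest::test_default_compile_load
  #   tests/aot_test.py::AotHloIdenticalTest::test_default_hlo_match
  #   tests/train_test.py::TrainTest::test_one_step
  # This invocation then runs only the last one:
  python3 -m pytest -v -k "not AotHloIdenticalTest and not CompileThenLoad"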

src/MaxText/maxtext_utils.py

Lines changed: 2 additions & 2 deletions

@@ -164,7 +164,7 @@ def should_prevent_cse_in_remat(config):
   return True


-def load_compiled(config, partial_train, state):
+def load_compiled(config, partial_train, state, execution_devices):
   """# Loading a serialized compiled train step function."""

   # Currently partial_train and state are needed to reconstruct
@@ -187,7 +187,7 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
   shaped_input_args = (state, shaped_batch, example_rng)
   shaped_input_kwargs = {}
   in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs)
-  p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
+  p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree, execution_devices=execution_devices)
   return p_train_step

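The new execution_devices argument is threaded through to JAX's AOT deserializer. A minimal sketch of the round trip this function builds on, assuming deserialize_and_load comes from jax.experimental.serialize_executable and that the installed JAX version accepts the execution_devices keyword (only the keyword itself is confirmed by this diff):

  # Serialize a compiled executable, then rebuild it bound to an explicit
  # device list, mirroring the execution_devices plumbing above.
  import jax
  import jax.numpy as jnp
  from jax.experimental.serialize_executable import serialize, deserialize_and_load

  def step(x):
    return jnp.sin(x) * 2.0

  compiled = jax.jit(step).lower(jnp.zeros((8,), jnp.float32)).compile()
  payload, in_tree, out_tree = serialize(compiled)

  # Later, possibly in a different process:
  restored = deserialize_and_load(payload, in_tree, out_tree, execution_devices=jax.devices())
  print(restored(jnp.arange(8, dtype=jnp.float32)))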

src/MaxText/train.py

Lines changed: 29 additions & 8 deletions

@@ -228,7 +228,12 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
     rng2: A new rng key that can be used in future calls.

   """
-  reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn
+  reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = (
+      [],
+      [],
+      [],
+      loss_fn,
+  )
   if config.use_dpo:
     state, reference_params = _split_dpo_state(state)
     state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings)
@@ -252,15 +257,19 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
   if config.optimizer_memory_host_offload:
     if config.use_dpo:
       reference_params = jax.device_put(
-          reference_params, max_utils.with_memory_kind(reference_params_sharding, "device")
+          reference_params,
+          max_utils.with_memory_kind(reference_params_sharding, "device"),
       )
       extra_dpo_args = [reference_params]
   if config.shard_optimizer_over_data:
     params = jax.tree.map(jax.lax.with_sharding_constraint, params, params_shardings)
   grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)
   (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)

-  raw_grads = jax.tree_util.tree_map(lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, raw_grads)
+  raw_grads = jax.tree_util.tree_map(
+      lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
+      raw_grads,
+  )
   intermediate_outputs = aux["intermediate_outputs"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
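The reflowed cast above is the standard JAX pattern for selectively downcasting gradients: only float32 leaves are converted to config.grad_dtype, everything else passes through. A self-contained sketch with stand-in names:

  import jax
  import jax.numpy as jnp

  grads = {
      "w": jnp.ones((2, 2), jnp.float32),
      "scale": jnp.ones((), jnp.bfloat16),  # already low precision; untouched
  }
  grad_dtype = jnp.bfloat16  # stand-in for config.grad_dtype

  cast_grads = jax.tree_util.tree_map(
      lambda x: x.astype(grad_dtype) if x.dtype == jnp.float32 else x,
      grads,
  )
  print(jax.tree_util.tree_map(lambda x: x.dtype, cast_grads))  # both leaves now bfloat16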
@@ -274,7 +283,10 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
     state = state.replace(
         opt_state=jax.device_put(
             state.opt_state,
-            jax.tree_util.tree_map(lambda x: x.with_memory_kind(kind="device"), state_mesh_shardings.opt_state),
+            jax.tree_util.tree_map(
+                lambda x: x.with_memory_kind(kind="device"),
+                state_mesh_shardings.opt_state,
+            ),
         )
     )
     # Move all parameters to device before optimizer update
@@ -378,16 +390,25 @@ def train_loop(config, recorder, state=None):
   params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

   p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
-      config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator, params_shardings
+      config,
+      model,
+      mesh,
+      state,
+      state_mesh_shardings,
+      train_step,
+      eval_step,
+      eval_data_iterator,
+      params_shardings,
   )

   with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
     shaped_batch = maxtext_utils.get_shaped_batch(config)
     if config.shard_optimizer_over_data:
       state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
-    compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
-    compiled_stats = compiled.memory_analysis()
-    max_utils.print_compiled_memory_stats(compiled_stats)
+    if config.compiled_trainstep_file == "":  # compile only when there is no pre-compiled file loaded
+      compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
+      compiled_stats = compiled.memory_analysis()
+      max_utils.print_compiled_memory_stats(compiled_stats)

   start_step = get_first_step(state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
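The new guard skips ahead-of-time lowering when a pre-compiled train step is being loaded from compiled_trainstep_file, avoiding a redundant compile. The guarded chain is the standard JAX AOT API; a minimal sketch with a toy function:

  import jax
  import jax.numpy as jnp

  def toy_step(x):
    return (x * 2.0).sum()

  # Lower for concrete shapes, compile, and inspect memory statistics.
  lowered = jax.jit(toy_step).lower(jnp.zeros((4, 4), jnp.float32))
  compiled = lowered.compile()
  stats = compiled.memory_analysis()  # backend-specific; may be None
  if stats is not None:
    print(stats)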

src/MaxText/train_utils.py

Lines changed: 4 additions & 3 deletions

@@ -90,10 +90,11 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr

   # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
   if config.compiled_trainstep_file != "":
-    print("Loading the compiled function...", flush=True)
+    max_logging.log("Loading the compiled function...")
+    execution_devices = model.mesh.devices.flatten().tolist()
     # Need to pass train signature and state to determine i/o shapes of train_state for now.
-    p_train_step = maxtext_utils.load_compiled(config, functional_train, state)
-    print("Loaded compiled function!", flush=True)
+    p_train_step = maxtext_utils.load_compiled(config, functional_train, state, execution_devices)
+    max_logging.log("Loaded compiled function!")
   else:
     p_train_step = jax.jit(
         functional_train,
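execution_devices is obtained by flattening the mesh's device grid into a plain Python list. A minimal sketch of that derivation, assuming a simple two-axis mesh over the locally visible devices (axis names are illustrative, not MaxText's):

  import numpy as np
  import jax
  from jax.sharding import Mesh

  # Mesh stores devices as a NumPy ndarray, so flatten().tolist() yields the
  # ordered device list that deserialize_and_load expects.
  devices = np.array(jax.devices()).reshape(-1, 1)
  mesh = Mesh(devices, axis_names=("data", "model"))
  execution_devices = mesh.devices.flatten().tolist()
  print(execution_devices)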

tests/xaot_test.py

Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+These tests verify the Compile-Then-Load workflow.
+They ensure that a model compiled via train_compile.py can be successfully
+loaded and executed by train.py.
+"""
+
+import tempfile
+import unittest
+import pytest
+import os
+import shutil
+import jax
+from MaxText.globals import MAXTEXT_PKG_DIR
+from MaxText import train_compile
+from MaxText import train
+
+
+class CompileThenLoadTest(unittest.TestCase):
+  """Tests for the split compile-and-train workflow."""
+
+  def setUp(self):
+    """Create a temporary directory for the compiled pickle file."""
+    self.temp_dir = tempfile.mkdtemp()
+    self.pickle_file = os.path.join(self.temp_dir, "test_compiled_train.pickle")
+
+    # Ensure the JAX compilation cache doesn't interfere with clean test runs.
+    jax.config.update("jax_enable_compilation_cache", False)
+
+  def tearDown(self):
+    """Clean up the temporary directory."""
+    if os.path.exists(self.temp_dir):
+      shutil.rmtree(self.temp_dir)
+
+  def get_device_user_facing_name(self):
+    """Gets the TPU device's user-facing name to generate correct AOT arguments."""
+    devices = jax.devices()
+    if not devices or "tpu" not in devices[0].platform.lower():
+      pytest.skip("This test requires a TPU environment.")
+
+    num_devices = len(devices)
+    device_kind = devices[0].device_kind
+    device_info = {
+        "TPU v4": ("v4", 2 * num_devices),
+        "TPU v5 lite": ("v5e", num_devices),
+        "TPU v5": ("v5p", 2 * num_devices),
+        "TPU v6": ("v6e", num_devices),
+    }
+
+    prefix, topology_devices = next((v for k, v in device_info.items() if k in device_kind), (None, None))
+    if prefix is None:
+      raise ValueError(f"Unsupported TPU device kind for AOT test: {device_kind}")
+
+    return f"{prefix}-{topology_devices}"
+
+  def run_compile_then_load(self, test_name, *extra_args):
+    """
+    Executes the compile step, checks that the pickle file was written,
+    then executes the load/train step.
+    """
+
+    # Arguments shared by the compile and train invocations.
+    shared_args = [
+        "global_parameter_scale=1",
+        "per_device_batch_size=4",
+        "steps=1",
+        "learning_rate=1e-3",
+        "dataset_type=synthetic",
+        "enable_checkpointing=False",
+    ]
+
+    if extra_args:
+      shared_args.extend(extra_args)
+
+    # Compilation
+    topology = self.get_device_user_facing_name()
+
+    compile_specific_args = [
+        f"compile_topology={topology}",
+        "compile_topology_num_slices=1",
+        f"compiled_trainstep_file={self.pickle_file}",
+    ]
+
+    compile_argv = (
+        (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(compile_specific_args)
+    )
+
+    print(f"\n--- Starting compilation step for {test_name} ---")
+    # Clear caches before compiling to ensure a clean state.
+    jax.clear_caches()
+    train_compile.main(compile_argv)
+
+    # Assert the pickle file was actually created.
+    assert os.path.exists(self.pickle_file), f"Compilation failed: {self.pickle_file} was not created."
+
+    load_specific_args = [
+        "base_output_directory=gs://runner-maxtext-logs",
+        f"run_name=compile_then_load_{test_name}",
+        f"compiled_trainstep_file={self.pickle_file}",
+    ]
+
+    train_argv = (
+        (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(load_specific_args)
+    )
+
+    print(f"\n--- Starting load/train step for {test_name} ---")
+    # Clear caches before training to ensure we actually load from the pickle.
+    jax.clear_caches()
+    train.main(train_argv)
+
+    print(f"Successfully compiled and loaded for test {test_name}!")
+
+  @pytest.mark.tpu_only
+  def test_default_compile_load(self):
+    self.run_compile_then_load("default_run")
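Assuming the repository's usual pytest conventions, the new test can be run in isolation on a TPU host with something like:

  python3 -m pytest tests/xaot_test.py -v -m tpu_only -k CompileThenLoad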
