
Commit f623f83

Merge pull request #2606 from AI-Hypercomputer:legacy_tests
PiperOrigin-RevId: 830700561
2 parents b63ce5c + 995867e commit f623f83

File tree

5 files changed: +346 / -68 lines changed


end_to_end/tpu/test_decode.sh

Lines changed: 0 additions & 37 deletions
This file was deleted.

end_to_end/tpu/test_determinism.sh

Lines changed: 0 additions & 31 deletions
This file was deleted.

tests/decode_tests.py

Lines changed: 78 additions & 0 deletions
@@ -14,12 +14,14 @@
 
 """Tests for decode with various configs."""
 
+import io
 import os
 import unittest
 
 import pytest
 
 from absl.testing import absltest
+from contextlib import redirect_stdout
 
 from MaxText.decode import main as decode_main
 from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT
@@ -28,6 +30,7 @@
 class DecodeTests(unittest.TestCase):
   """Tests decode with various configs."""
 
+  GEMMA_2B_CKPT_PATH = "gs://maxtext-gemma/2b/2025-11-04-04-33//0/items"
   CONFIGS = {
       "base": [  # tests decode
           None,
@@ -70,6 +73,41 @@ class DecodeTests(unittest.TestCase):
           "per_device_batch_size=.25",
           rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.llama2')}",
       ],
+      "decode_sampling": [
+          None,
+          os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+          "base_output_directory=gs://runner-maxtext-logs",
+          "run_name=runner_test",
+          f"load_parameters_path={GEMMA_2B_CKPT_PATH}",
+          "per_device_batch_size=1",
+          "max_prefill_predict_length=8",
+          "max_target_length=16",
+          "dataset_type=synthetic",
+          "steps=10",
+          "async_checkpointing=False",
+          "model_name=gemma-2b",
+          rf"tokenizer_path={os.path.join('src', MAXTEXT_ASSETS_ROOT, 'tokenizer.gemma')}",
+          "attention=dot_product",
+          "prompt=I love to",
+          "skip_jax_distributed_system=True",
+      ],
+  }
+  SAMPLING_STRATEGY_CONFIG = {
+      "greedy": [
+          "decode_sampling_strategy=greedy",
+      ],
+      "weighted": [
+          "decode_sampling_strategy=weighted",
+          "decode_sampling_temperature=.00001",
+      ],
+      "nucleus": [
+          "decode_sampling_strategy=nucleus",
+          "decode_sampling_nucleus_p=0",
+      ],
+      "topk": [
+          "decode_sampling_strategy=topk",
+          "decode_sampling_top_k=1",
+      ],
   }
 
   @pytest.mark.tpu_only
@@ -96,6 +134,46 @@ def test_tpu_pdb_lt_1(self):
   def test_gpu_pdb_lt_1(self):
     decode_main(DecodeTests.CONFIGS["pdb_lt_1"] + ["attention=dot_product"])
 
+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_decode_greedy_sampling(self):
+    config = DecodeTests.CONFIGS["decode_sampling"] + DecodeTests.SAMPLING_STRATEGY_CONFIG["greedy"]
+    captured_out = run_decoding(config)
+    expected_output = "Input `I love to` -> ` travel and I love to write"
+    assert expected_output in captured_out
+
+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_decode_weighted_sampling(self):
+    config = DecodeTests.CONFIGS["decode_sampling"] + DecodeTests.SAMPLING_STRATEGY_CONFIG["weighted"]
+    captured_out = run_decoding(config)
+    expected_output = "Input `I love to` -> ` travel and I love to write"
+    assert expected_output in captured_out
+
+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_decode_nucleus_sampling(self):
+    config = DecodeTests.CONFIGS["decode_sampling"] + DecodeTests.SAMPLING_STRATEGY_CONFIG["nucleus"]
+    captured_out = run_decoding(config)
+    expected_output = "Input `I love to` -> ` travel and I love to write"
+    assert expected_output in captured_out
+
+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_decode_topk_sampling(self):
+    config = DecodeTests.CONFIGS["decode_sampling"] + DecodeTests.SAMPLING_STRATEGY_CONFIG["topk"]
+    captured_out = run_decoding(config)
+    expected_output = "Input `I love to` -> ` travel and I love to write"
+    assert expected_output in captured_out
+
+
+def run_decoding(config):
+  f = io.StringIO()
+  with redirect_stdout(f):
+    decode_main(config)
+  captured_out = f.getvalue()
+  return captured_out
+
 
 if __name__ == "__main__":
   absltest.main()
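A note on the four sampling tests added above: each decode_sampling_strategy is pinned to a degenerate parameter (a near-zero temperature for weighted, nucleus_p=0 for nucleus, top_k=1 for topk), so every strategy collapses to greedy argmax decoding, which is why all four tests can assert the identical continuation. A back-of-the-envelope NumPy sketch of that collapse, with invented logits, not MaxText's actual sampler:

import numpy as np

logits = np.array([2.0, 0.5, -1.0])    # invented logits for one decode step
greedy_token = int(np.argmax(logits))  # greedy: plain argmax

# weighted: temperature -> 0 sharpens the softmax until essentially all
# probability mass sits on the argmax token.
scaled = logits / 1e-5
probs = np.exp(scaled - scaled.max())  # subtract max for numerical stability
probs /= probs.sum()
assert int(np.argmax(probs)) == greedy_token and probs.max() > 0.999999

# topk with k=1 keeps only the highest-logit token; nucleus with p=0 keeps the
# smallest set whose cumulative probability exceeds p, i.e. the single top token.
assert int(np.argsort(logits)[-1]) == greedy_token

The run_decoding helper relies only on the standard library (io.StringIO plus contextlib.redirect_stdout), so the assertions see exactly what decode_main prints to stdout.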

tests/determinism_test.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests to verify the deterministic nature of MaxText training runs.
+
+This module ensures that when the MaxText training is executed multiple times
+with identical configurations, the loss metrics across runs are exactly
+the same.
+"""
+
+import datetime
+import json
+import os
+import unittest
+
+import pytest
+
+from MaxText.train import main as train_main
+from MaxText.globals import MAXTEXT_PKG_DIR
+
+
+def compare_target_metrics(metrics_files, target):
+  """Asserts over loss values from two runs."""
+  loss = []
+  for file in metrics_files:
+    with open(file, "rt", encoding="utf8") as f:
+      run_loss = json.loads(f.readlines()[-1])[target]
+    loss.append(run_loss)
+  assert loss[0] == loss[1]
+
+
+class DeterminismTests(unittest.TestCase):
+  """Tests determinism by running MaxText training multiple times and comparing loss."""
+
+  @pytest.mark.tpu_only
+  @pytest.mark.scheduled_only
+  def test_determinism(self):
+    """Executes two identical training runs and verifies training loss is the same."""
+    run_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    common_config = [
+        None,
+        os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+        "steps=5",
+        "enable_checkpointing=False",
+        "enable_data_shuffling=True",
+        "enable_dropout=False",
+        "base_output_directory=gs://runner-maxtext-logs",
+        "dataset_path=gs://maxtext-dataset",
+        "skip_jax_distributed_system=True",
+    ]
+    train_1_config = common_config + [
+        f"run_name={run_name}_1",
+        f"metrics_file={run_name}_1_metrics.txt",
+    ]
+    train_2_config = common_config + [
+        f"run_name={run_name}_2",
+        f"metrics_file={run_name}_2_metrics.txt",
+    ]
+
+    train_main(train_1_config)
+    train_main(train_2_config)
+    compare_target_metrics([f"{run_name}_1_metrics.txt", f"{run_name}_2_metrics.txt"], "learning/loss")
