1414
1515"""Tests for decode with various configs."""
1616
17+ import io
1718import os
1819import unittest
1920
2021import pytest
2122
2223from absl .testing import absltest
24+ from contextlib import redirect_stdout
2325
2426from MaxText .decode import main as decode_main
2527from MaxText .globals import MAXTEXT_PKG_DIR , MAXTEXT_ASSETS_ROOT
2830class DecodeTests (unittest .TestCase ):
2931 """Tests decode with various configs."""
3032
33+ GEMMA_2B_CKPT_PATH = "gs://maxtext-gemma/2b/2025-11-04-04-33//0/items"
3134 CONFIGS = {
3235 "base" : [ # tests decode
3336 None ,
@@ -70,6 +73,41 @@ class DecodeTests(unittest.TestCase):
7073 "per_device_batch_size=.25" ,
7174 rf"tokenizer_path={ os .path .join ('src' , MAXTEXT_ASSETS_ROOT , 'tokenizer.llama2' )} " ,
7275 ],
76+ "decode_sampling" : [
77+ None ,
78+ os .path .join (MAXTEXT_PKG_DIR , "configs" , "base.yml" ),
79+ "base_output_directory=gs://runner-maxtext-logs" ,
80+ "run_name=runner_test" ,
81+ f"load_parameters_path={ GEMMA_2B_CKPT_PATH } " ,
82+ "per_device_batch_size=1" ,
83+ "max_prefill_predict_length=8" ,
84+ "max_target_length=16" ,
85+ "dataset_type=synthetic" ,
86+ "steps=10" ,
87+ "async_checkpointing=False" ,
88+ "model_name=gemma-2b" ,
89+ rf"tokenizer_path={ os .path .join ('src' , MAXTEXT_ASSETS_ROOT , 'tokenizer.gemma' )} " ,
90+ "attention=dot_product" ,
91+ "prompt=I love to" ,
92+ "skip_jax_distributed_system=True" ,
93+ ],
94+ }
95+ SAMPLING_STRATEGY_CONFIG = {
96+ "greedy" : [
97+ "decode_sampling_strategy=greedy" ,
98+ ],
99+ "weighted" : [
100+ "decode_sampling_strategy=weighted" ,
101+ "decode_sampling_temperature=.00001" ,
102+ ],
103+ "nucleus" : [
104+ "decode_sampling_strategy=nucleus" ,
105+ "decode_sampling_nucleus_p=0" ,
106+ ],
107+ "topk" : [
108+ "decode_sampling_strategy=topk" ,
109+ "decode_sampling_top_k=1" ,
110+ ],
73111 }
74112
75113 @pytest .mark .tpu_only
@@ -96,6 +134,46 @@ def test_tpu_pdb_lt_1(self):
96134 def test_gpu_pdb_lt_1 (self ):
97135 decode_main (DecodeTests .CONFIGS ["pdb_lt_1" ] + ["attention=dot_product" ])
98136
137+ @pytest .mark .tpu_only
138+ @pytest .mark .scheduled_only
139+ def test_decode_greedy_sampling (self ):
140+ config = DecodeTests .CONFIGS ["decode_sampling" ] + DecodeTests .SAMPLING_STRATEGY_CONFIG ["greedy" ]
141+ captured_out = run_decoding (config )
142+ expected_output = "Input `I love to` -> ` travel and I love to write"
143+ assert expected_output in captured_out
144+
145+ @pytest .mark .tpu_only
146+ @pytest .mark .scheduled_only
147+ def test_decode_weighted_sampling (self ):
148+ config = DecodeTests .CONFIGS ["decode_sampling" ] + DecodeTests .SAMPLING_STRATEGY_CONFIG ["weighted" ]
149+ captured_out = run_decoding (config )
150+ expected_output = "Input `I love to` -> ` travel and I love to write"
151+ assert expected_output in captured_out
152+
153+ @pytest .mark .tpu_only
154+ @pytest .mark .scheduled_only
155+ def test_decode_nucleus_sampling (self ):
156+ config = DecodeTests .CONFIGS ["decode_sampling" ] + DecodeTests .SAMPLING_STRATEGY_CONFIG ["nucleus" ]
157+ captured_out = run_decoding (config )
158+ expected_output = "Input `I love to` -> ` travel and I love to write"
159+ assert expected_output in captured_out
160+
161+ @pytest .mark .tpu_only
162+ @pytest .mark .scheduled_only
163+ def test_decode_topk_sampling (self ):
164+ config = DecodeTests .CONFIGS ["decode_sampling" ] + DecodeTests .SAMPLING_STRATEGY_CONFIG ["topk" ]
165+ captured_out = run_decoding (config )
166+ expected_output = "Input `I love to` -> ` travel and I love to write"
167+ assert expected_output in captured_out
168+
169+
170+ def run_decoding (config ):
171+ f = io .StringIO ()
172+ with redirect_stdout (f ):
173+ decode_main (config )
174+ captured_out = f .getvalue ()
175+ return captured_out
176+
99177
100178if __name__ == "__main__" :
101179 absltest .main ()
0 commit comments