@@ -1,5 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
+import re
 import subprocess
 import sys
 from contextlib import redirect_stderr, redirect_stdout
@@ -87,7 +87,15 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     mocked_input.side_effect = ["Hello", stop_iteration]
 
     config_path = fake_checkpoint_dir / "model_config.yaml"
-    config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
+    config = {
+        "name": "Llama 3",
+        "block_size": 128,
+        "vocab_size": 50,
+        "n_layer": 2,
+        "n_head": 4,
+        "n_embd": 8,
+        "rotary_percentage": 1,
+    }
     config_path.write_text(yaml.dump(config))
 
     load_mock = Mock()
@@ -112,10 +120,8 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     assert generate_mock.mock_calls == [
         call(ANY, tensor_like, 128, temperature=2.0, top_k=2, stop_tokens=([tokenizer_mock.return_value.eos_id],))
     ]
-    # # only the generated result is printed to stdout
-    assert out.getvalue() == ">> Reply: foo bar baz\n"
-
-    assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
+    # only the generated result is printed to stdout
+    assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)
 
 
 @pytest.mark.parametrize("mode", ["file", "entrypoint"])
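
The rewritten stdout assertion relies on re.DOTALL, which lets "." match newlines so a single pattern can span the "Now chatting with ..." banner and the ">> Reply:" line of the captured output. A minimal sketch of that behavior, using an illustrative stdout capture rather than the test's real output:

import re

# Hypothetical captured stdout; only the banner and reply text are implied by the test's pattern.
captured = (
    "Now chatting with Llama 3.\n"
    "To exit, press 'Enter' on an empty prompt.\n"
    "\n"
    ">> Prompt: Hello\n"
    ">> Reply: foo bar baz\n"
)

pattern = "Now chatting with Llama 3.*>> .*Reply: foo bar baz"

# Without re.DOTALL, "." stops at the first newline and the match fails.
assert re.match(pattern, captured) is None
# With re.DOTALL, the pattern spans the whole multi-line transcript.
assert re.match(pattern, captured, re.DOTALL)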