Commit 09f2406

awaelchli authored and rasbt committed
Make chat output less verbose (#1123)
1 parent 3bb7a62 commit 09f2406

2 files changed: +14 −8 lines

litgpt/chat/base.py

Lines changed: 2 additions & 2 deletions
@@ -142,7 +142,6 @@ def main(
         print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
         merge_lora(checkpoint_path)
 
-    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     with fabric.init_module(empty_init=True):
         model = GPT(config)
     # enable the kv cache
@@ -163,13 +162,14 @@ def main(
     prompt_style = load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
     stop_tokens = prompt_style.stop_tokens(tokenizer)
 
+    print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")
     L.seed_everything(1234)
     while True:
         try:
             prompt = input(">> Prompt: ")
         except KeyboardInterrupt:
             break
-        if not prompt:
+        if prompt.lower().strip() in ("", "quit", "exit"):
             break
         prompt = prompt_style.apply(prompt=prompt)
         encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
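The change removes the config dump to stderr and instead greets the user once; the new exit check normalizes input with lower() and strip(), so a bare Enter, "Quit", or "  EXIT  " all end the session. A minimal standalone sketch of just this control flow (model loading, prompt styling, and generation are omitted; model_name stands in for config.name):

def chat_loop(model_name: str) -> None:
    # One-time banner instead of printing the full config to stderr.
    print(f"Now chatting with {model_name}.\nTo exit, press 'Enter' on an empty prompt.\n")
    while True:
        try:
            prompt = input(">> Prompt: ")
        except KeyboardInterrupt:
            break
        # Blank input, "quit", or "exit" (any casing, padded or not) ends the chat.
        if prompt.lower().strip() in ("", "quit", "exit"):
            break
        print(f">> Reply: (a generated response to {prompt!r} would go here)")

if __name__ == "__main__":
    chat_loop("Llama 3")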

tests/test_chat.py

Lines changed: 12 additions & 6 deletions
@@ -1,5 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-
+import re
 import subprocess
 import sys
 from contextlib import redirect_stderr, redirect_stdout
@@ -87,7 +87,15 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     mocked_input.side_effect = ["Hello", stop_iteration]
 
     config_path = fake_checkpoint_dir / "model_config.yaml"
-    config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
+    config = {
+        "name": "Llama 3",
+        "block_size": 128,
+        "vocab_size": 50,
+        "n_layer": 2,
+        "n_head": 4,
+        "n_embd": 8,
+        "rotary_percentage": 1,
+    }
     config_path.write_text(yaml.dump(config))
 
     load_mock = Mock()
@@ -112,10 +120,8 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
     assert generate_mock.mock_calls == [
         call(ANY, tensor_like, 128, temperature=2.0, top_k=2, stop_tokens=([tokenizer_mock.return_value.eos_id],))
     ]
-    # # only the generated result is printed to stdout
-    assert out.getvalue() == ">> Reply: foo bar baz\n"
-
-    assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
+    # only the generated result is printed to stdout
+    assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", out.getvalue(), re.DOTALL)
 
 
 @pytest.mark.parametrize("mode", ["file", "entrypoint"])
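The rewritten assertion uses re.match, which anchors at the start of the captured stdout, so the banner must appear first; re.DOTALL lets each .* span the newlines between the banner, the prompt marker, and the reply. The test config gains a "name" key because the banner interpolates config.name. A quick self-contained check of the pattern against output shaped like what the chat script now prints (the sample string is illustrative, not the exact captured value):

import re

sample_out = (
    "Now chatting with Llama 3.\n"
    "To exit, press 'Enter' on an empty prompt.\n\n"
    ">> Prompt: >> Reply: foo bar baz\n"
)
# Anchored at the start; DOTALL lets ".*" cross line breaks.
assert re.match("Now chatting with Llama 3.*>> .*Reply: foo bar baz", sample_out, re.DOTALL)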
