Commit 7953793

Added test

1 parent 30fbada commit 7953793

2 files changed (+71, -2 lines)

litgpt/generate/base.py

Lines changed: 2 additions & 1 deletion
@@ -404,7 +404,8 @@ def generate(
     Takes a list of prompts as input (1D tensors, can be of different lengths)
     and generates tokens as specified.
 
-    Choice of `pro mpt_chunksize`:
+    Choice of `prompt_chunksize`:
+
     This parameter can speed up inference for long prompts. Let
     `M = min_prompt_size - max_prefill_length`, the minimum prompt length
     minus the max prefill length of the KV cache. If `M > 0`, the prompt
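
The docstring hunk above is cut off at the hunk boundary, so the following is only a plausible reading of the idea, not the litgpt implementation: once a prompt is longer than the KV cache's maximum prefill length, the overhang is fed to the model in chunks of `prompt_chunksize` tokens rather than one token at a time. In the sketch below, `chunked_prefill` and the standalone `max_prefill_length` argument are hypothetical names used purely for illustration.

import torch

def chunked_prefill(prompt: torch.Tensor, max_prefill_length: int, prompt_chunksize: int):
    """Yield the slices in which a long prompt could be fed to the model (illustration only)."""
    # The first `max_prefill_length` tokens fit into the KV cache in a single forward pass.
    yield prompt[:max_prefill_length]
    # The remaining M = len(prompt) - max_prefill_length tokens are processed
    # in chunks of `prompt_chunksize` instead of token by token.
    for start in range(max_prefill_length, prompt.numel(), prompt_chunksize):
        yield prompt[start : start + prompt_chunksize]

prompt = torch.arange(42)  # a 42-token prompt, the longest prompt in the new test below
chunks = list(chunked_prefill(prompt, max_prefill_length=16, prompt_chunksize=4))
assert sum(c.numel() for c in chunks) == prompt.numel()
print([c.numel() for c in chunks])  # [16, 4, 4, 4, 4, 4, 4, 2]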

tests/test_generate.py

Lines changed: 69 additions & 1 deletion
@@ -16,6 +16,8 @@
 import litgpt.generate.base as generate
 from litgpt import GPT, Config
 from litgpt.generate.base import sample
+from litgpt.kvcache.base import KVCacheParams
+from litgpt.kvcache.test_utils import create_kv_cache
 
 
 skip_in_ci_on_macos = pytest.mark.skipif(
@@ -128,7 +130,73 @@ def test_generate_single_vs_batch(max_seq_length):
     for rb, rs, prompt in zip(res_batch, res_single, prompts):
         print(f"rs: {rs}\nrb: {rb}\npr: {prompt}")
         torch.testing.assert_close(rs, rb)
-    print("OK")
+
+
+def test_prompt_chunksize():
+    import lightning as L
+
+    L.seed_everything(1234)
+
+    batch_size = 3
+    vocab_size = 128
+    max_seq_length = 64
+    n_layer = 2
+    params = KVCacheParams(
+        batch_size=batch_size,
+        n_query_groups=4,
+        cache_length=16,
+        head_size=8,
+        n_head=4,
+        device=torch.device("cpu"),
+        dtype=torch.bfloat16,
+    )
+    kv_cache = create_kv_cache("mostrec-default", params)
+    config = Config(
+        block_size=max_seq_length,
+        vocab_size=vocab_size,
+        n_layer=n_layer,
+        n_head=params.n_head,
+        n_embd=params.n_head * params.head_size,
+        rotary_percentage=1,
+    )
+    model = GPT(
+        config,
+        kv_cache=[
+            create_kv_cache("mostrec-default", params)
+            for _ in range(n_layer)
+        ],
+    )
+
+    prompt_sizes = [32, 37, 42]
+    prompts = [
+        torch.randint(
+            low=0,
+            high=vocab_size,
+            size=(sz,)
+        )
+        for sz in prompt_sizes
+    ]
+
+    results = []
+    chunk_sizes = [1, 2, 4, 5, 16]
+    for prompt_chunksize in chunk_sizes:
+        results.append(
+            generate.generate(
+                model=model,
+                prompts=prompts,
+                prompt_chunksize=prompt_chunksize,
+                max_returned_tokens=max_seq_length,
+                top_k=1,
+            )
+        )
+
+    result_1 = results[0]
+    assert len(result_1) == batch_size
+    for prompt_chunksize, result in zip(chunk_sizes[1:], results[1:]):
+        print(f"prompt_chunksize: 1 versus {prompt_chunksize}")
+        assert len(result) == batch_size
+        for res1, resn in zip(result_1, result):
+            print(f"res1: {res1}\nres{prompt_chunksize}: {resn}")
+            torch.testing.assert_close(res1, resn)
 
 
 @skip_in_ci_on_macos
