import litgpt.generate.base as generate
from litgpt import GPT, Config
from litgpt.generate.base import sample
+from litgpt.kvcache.base import KVCacheParams
+from litgpt.kvcache.test_utils import create_kv_cache


skip_in_ci_on_macos = pytest.mark.skipif(
@@ -128,7 +130,73 @@ def test_generate_single_vs_batch(max_seq_length):
    for rb, rs, prompt in zip(res_batch, res_single, prompts):
        print(f"rs: {rs}\nrb: {rb}\npr: {prompt}")
        torch.testing.assert_close(rs, rb)
-    print("OK")
+
+
+def test_prompt_chunksize():
+    """Results of `generate.generate` must be independent of `prompt_chunksize`."""
+    import lightning as L
+    L.seed_everything(1234)
+
+    batch_size = 3
+    vocab_size = 128
+    max_seq_length = 64
+    n_layer = 2
+    # Tiny model and KV cache configuration (CPU, 2 layers), so the test stays fast
+    params = KVCacheParams(
+        batch_size=batch_size,
+        n_query_groups=4,
+        cache_length=16,
+        head_size=8,
+        n_head=4,
+        device=torch.device("cpu"),
+        dtype=torch.bfloat16,
+    )
+    config = Config(
+        block_size=max_seq_length,
+        vocab_size=vocab_size,
+        n_layer=n_layer,
+        n_head=params.n_head,
+        n_embd=params.n_head * params.head_size,
+        rotary_percentage=1,
+    )
+    # One KV cache per transformer layer
+    model = GPT(
+        config,
+        kv_cache=[
+            create_kv_cache("mostrec-default", params)
+            for _ in range(n_layer)
+        ],
+    )
+
+    # Prompts of different lengths, so chunked prefill also hits uneven final chunks
+    prompt_sizes = [32, 37, 42]
+    prompts = [
+        torch.randint(low=0, high=vocab_size, size=(sz,))
+        for sz in prompt_sizes
+    ]
+
+    # Greedy decoding (top_k=1), so outputs are deterministic for every chunk size
+    results = []
+    chunk_sizes = [1, 2, 4, 5, 16]
+    for prompt_chunksize in chunk_sizes:
+        results.append(
+            generate.generate(
+                model=model,
+                prompts=prompts,
+                prompt_chunksize=prompt_chunksize,
+                max_returned_tokens=max_seq_length,
+                top_k=1,
+            )
+        )
+
+    # Every chunk size must reproduce the tokens obtained with chunk size 1
+    result_1 = results[0]
+    assert len(result_1) == batch_size
+    for prompt_chunksize, result in zip(chunk_sizes[1:], results[1:]):
+        print(f"prompt_chunksize: 1 versus {prompt_chunksize}")
+        assert len(result) == batch_size
+        for res1, resn in zip(result_1, result):
+            print(f"res1: {res1}\nres{prompt_chunksize}: {resn}")
+            torch.testing.assert_close(res1, resn)


@skip_in_ci_on_macos