Commit 212c7a7

Fixed tests
1 parent 16526f3 commit 212c7a7

7 files changed (+35, -19 lines changed)

litgpt/generate/adapter.py

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ def main(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
-        eos_id=tokenizer.eos_id,
+        eos_id=int(tokenizer.eos_id),
     )[0]
     t = time.perf_counter() - t0
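The only change here is the int(...) cast on eos_id; the same one-line change recurs in adapter_v2.py, base.py, full.py, sequentially.py, and tp.py below. The commit message does not say why, but a plausible reading is that tokenizer.eos_id is not always a plain Python int. A minimal, hypothetical sketch of what such a cast normalizes (the alternative scalar types shown are assumptions, not something this diff states):

    import numpy as np
    import torch

    # int(...) collapses common scalar types to a plain Python int, so the
    # eos_id handed to generate() always compares and serializes predictably.
    for eos_id in (2, np.int64(2), torch.tensor(2)):
        assert isinstance(int(eos_id), int)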

litgpt/generate/adapter_v2.py

Lines changed: 1 addition & 1 deletion

@@ -146,7 +146,7 @@ def main(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
-        eos_id=tokenizer.eos_id,
+        eos_id=int(tokenizer.eos_id),
     )[0]
     t = time.perf_counter() - t0

litgpt/generate/base.py

Lines changed: 11 additions & 1 deletion

@@ -415,6 +415,14 @@ def generate(
     However, KV cache eviction is done in a more coarse-grained manner,
     which can lead to worse performance.

+    Key-value caching:
+
+    KV caches must have been assigned in `model`, in that
+    `model.are_kv_caches_assigned() == True`. This is done by either
+    assigning KV caches with `model.assign_kv_caches(...)`, or by creating
+    default (dense) KV caches with `model.set_kv_cache(...)`. The latter does
+    not allow to control memory being used.
+
     Args:
         model: The model to use.
         prompts: List of batch_size 1D tensors, each being a prompt sequence

@@ -570,6 +578,8 @@ def main(
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
+        # enable the kv cache
+        model.set_kv_cache(batch_size=1)
     model.eval()

     if compile:

@@ -594,7 +604,7 @@ def main(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
-        eos_id=tokenizer.eos_id,
+        eos_id=int(tokenizer.eos_id),
     )[0]
     t = time.perf_counter() - t0
     fabric.print(tokenizer.decode(y))
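The docstring added above states a precondition: generate only runs once KV caches are attached to the model, either via model.assign_kv_caches(...) or via the default dense caches from model.set_kv_cache(...). A minimal sketch of that setup, assuming the interfaces visible in this diff; the config name, prompt tokens, sampling values, and the positional max-returned-tokens argument are illustrative assumptions, not part of the commit:

    import torch
    from litgpt.config import Config
    from litgpt.generate.base import generate
    from litgpt.model import GPT

    model = GPT(Config.from_name("pythia-14m"))  # placeholder: any small config
    model.max_seq_length = 128                   # bounds the default cache size
    model.eval()

    # Path 1: default (dense) caches; no control over the memory they use.
    model.set_kv_cache(batch_size=1)
    assert model.are_kv_caches_assigned()

    # Path 2 (instead of path 1): caller-built caches, one KVCache per layer,
    # e.g. an eviction-capable implementation. Schematic only:
    #   model.assign_kv_caches([make_cache(i) for i in range(model.config.n_layer)])

    prompt = torch.tensor([1, 2, 3])             # placeholder token ids
    with torch.inference_mode():
        y = generate(
            model,
            [prompt],      # per the docstring: a list of 1D prompt tensors
            32,            # assumed positional arg: max returned tokens
            temperature=0.8,
            top_k=50,
            top_p=1.0,
            eos_id=0,      # a plain int, matching the int(...) casts in this commit
        )[0]
    model.clear_kv_cache()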

litgpt/generate/full.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def main(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
-        eos_id=tokenizer.eos_id,
+        eos_id=int(tokenizer.eos_id),
     )[0]
     t = time.perf_counter() - t0

litgpt/generate/sequentially.py

Lines changed: 3 additions & 2 deletions

@@ -238,6 +238,7 @@ def main(
     # still, use init_tensor for the precision
     with fabric.init_tensor(), torch.device("meta"):
         model = GPT(config)
+        model.set_kv_cache(batch_size=1)
     print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

     t0 = time.perf_counter()

@@ -276,13 +277,13 @@ def main(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            eos_id=tokenizer.eos_id,
+            eos_id=int(tokenizer.eos_id),
         )[0]
         t = time.perf_counter() - t0
-        model.clear_kv_cache()
         print(tokenizer.decode(y))
         tokens_generated = y.size(0) - prompt_length
         print(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
         )
+        model.clear_kv_cache()
     print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

litgpt/generate/tp.py

Lines changed: 3 additions & 2 deletions

@@ -202,6 +202,7 @@ def main(
     # still, use init_tensor for the precision
     with fabric.init_tensor(), torch.device("meta"):
         model = GPT(config)
+        model.set_kv_cache(batch_size=1)
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

     # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device

@@ -253,14 +254,14 @@ def main(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            eos_id=tokenizer.eos_id,
+            eos_id=int(tokenizer.eos_id),
         )[0]
         t = time.perf_counter() - t0
-        model.clear_kv_cache()
         fabric.print(tokenizer.decode(y))
         tokens_generated = y.size(0) - prompt_length
         fabric.print(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
         )
+        model.clear_kv_cache()
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

litgpt/model.py

Lines changed: 15 additions & 11 deletions

@@ -101,8 +101,12 @@ def max_seq_length(self, value: int) -> None:
                 device=self.cos.device,
             )

-    def _are_kv_caches_assigned(self) -> bool:
-        return any(block.attn.kv_cache is not None for block in self.transformer.h)
+    def are_kv_caches_assigned(self) -> bool:
+        status = [block.attn.kv_cache is not None for block in self.transformer.h]
+        result = any(status)
+        if result and not all(status):
+            raise IndexError("Some layers have KV caches assigned, but not all")
+        return result

     def assign_kv_caches(
         self, kv_caches: List[KVCache]

@@ -120,7 +124,7 @@ def assign_kv_caches(
             kv_caches: KV caches, one for each layer of the model

         """
-        if self._are_kv_caches_assigned():
+        if self.are_kv_caches_assigned():
             raise ValueError("Model has KV caches assigned already")
         if len(kv_caches) != self.config.n_layer:
             raise ValueError(f"kv_caches must have one entry per layer, so {self.config.n_layer} entries ")

@@ -154,7 +158,7 @@ def set_kv_cache(
             `self.max_seq_length`

         """
-        if self._are_kv_caches_assigned() and not self._default_kv_cache:
+        if self.are_kv_caches_assigned() and not self._default_kv_cache:
             raise ValueError("Model has KV caches assigned already")
         if max_seq_length is None:
             max_seq_length = self.max_seq_length

@@ -269,15 +273,14 @@ def forward(
             raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
         for_prefill = False
         if input_pos is not None:
-            for_prefill = (input_pos == 0)
             # Few tokens generation. This needs a KV cache. If none is assigned,
             # the call fails
-            msg_suffix = f"."
-            for l_ix, block in enumerate(self.transformer.h):
-                kv_cache = block.attn.kv_cache
-                if kv_cache is None:
-                    raise ValueError("KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_cache'")
-                if not for_prefill:
+            if not self.are_kv_caches_assigned():
+                raise ValueError("KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_cache'")
+            for_prefill = (input_pos == 0)
+            if not for_prefill:
+                for l_ix, block in enumerate(self.transformer.h):
+                    kv_cache = block.attn.kv_cache
                     if kv_cache.next_token_pos is None:
                         raise ValueError("Inference calls need to start with pre-fill, i.e. 'input_pos=0'")
                     if kv_cache.next_token_pos != input_pos:

@@ -373,6 +376,7 @@ def clear_kv_cache(self) -> None:
         if self._default_kv_cache:
             for block in self.transformer.h:
                 block.attn.kv_cache = None
+            self._default_kv_cache = False

     def get_kv_cache_params(self) -> Optional[KVCacheParams]:
         kv_cache = self.transformer.h[0].attn.kv_cache
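Taken together, the model.py changes make are_kv_caches_assigned a public method that also flags partially assigned caches, and have clear_kv_cache reset the _default_kv_cache flag so a default cache can be rebuilt later. A minimal sketch of that lifecycle, assuming the interfaces shown in this diff (the "pythia-14m" config name is a placeholder):

    from litgpt.config import Config
    from litgpt.model import GPT

    model = GPT(Config.from_name("pythia-14m"))   # placeholder: any small config
    model.max_seq_length = 128
    assert not model.are_kv_caches_assigned()     # no caches yet

    model.set_kv_cache(batch_size=1)              # default dense caches
    assert model.are_kv_caches_assigned()

    model.clear_kv_cache()                        # now also resets _default_kv_cache
    assert not model.are_kv_caches_assigned()
    model.set_kv_cache(batch_size=1)              # caches can be recreated afterwards

    # The new consistency check: caches on some layers but not all is an error.
    model.transformer.h[0].attn.kv_cache = None   # simulate a partially assigned state
    try:
        model.are_kv_caches_assigned()
    except IndexError as err:
        print(err)  # "Some layers have KV caches assigned, but not all"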
