Skip to content

Commit 3228cd2

Browse files
author
Griffin Adams
committed
Move cache_kwargs to cache.py.
1 parent 11db56e commit 3228cd2

File tree

3 files changed

+100
-154
lines changed

3 files changed

+100
-154
lines changed

cache.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,93 @@
33
import torch
44
import torch.nn as nn
55
from prompt_compression import prompt_compressor_constructor
6+
import argparse
7+
8+
9+
def add_cache_arguments(parser: argparse.ArgumentParser):
    """Register all KV-cache-related CLI flags on ``parser``.

    The flags live in one ``cache_args`` argument group so that eval.py and
    generate.py share a single definition of the cache hyperparameters.

    Args:
        parser: The argparse parser to extend (mutated in place).
    """
    group = parser.add_argument_group("cache_args")
    # KV-Cache Kwargs
    group.add_argument(
        "--max_cache_length",
        type=float,
        default=[1.0],
        nargs="+",
        help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
        If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.",
    )
    group.add_argument(
        "--cache_strategy",
        default="full",
        choices=["full", "random", "window", "scissor", "l2"],
    )

    group.add_argument(
        "--prompt_compression_strategy",
        default="recent_global",
        choices=["recent_global", "snapkv", "l2"],
        help="If |prompt| exceeds max_cache_length, we need to specify a strategy for compressing it to max_cache_length.",
    )

    # Optional Cache Kwargs depending on cache_strategy
    group.add_argument(
        "--global_tokens",
        default=4,
        type=int,
        help="The number of initial tokens to always include in the KV-Cache. \
        If using window strategy, the actual window becomes max_cache_length - global_tokens.",
    )

    # Locality
    group.add_argument(
        "--recent_window",  # NB: for KVCacheWindow, recent_window is implicitly set to self.max_cache_length - self.global_tokens.
        default=10,  # 10 is default specified in ScissorHands paper ("r" in Algorithm 2).
        type=int,
        help="The number of recently generated tokens to always spare from eviction.",
    )

    # Scissorhands-specific Hyperparameters (--cache_strategy == "scissor")
    ## See Algorithm 1 & 2 in arxiv.org/abs/2305.17118
    group.add_argument(
        "--history_window_size",  # Equivalent to "m" in Algorithm 2.
        default=400,  # 400 is default specified in paper.
        type=int,
        help="The number of past tokens to consider when computing 'Heavy Hitters' in the KV-Cache.",
    )
    group.add_argument(
        # NOTE(review): the original comment labeled this "m", identical to
        # history_window_size — likely a copy-paste slip; confirm the intended
        # symbol against Algorithm 2 of the paper.
        "--drop_amount",
        default=0.5,  # 0.4 is default specified in paper.
        type=float,
        help="The number of tokens to evict when the KV-Cache reaches capacity (max_cache_length). Expressed as a fraction of max_cache_length.",
    )
    group.add_argument(
        "-attn_thresholding",  # Kept for backward compatibility with the original single-dash spelling.
        "--attn_thresholding",  # Fix: also accept the conventional double-dash form used by every other flag.
        default=False,
        action="store_true",
        help="Whether to accumulate number of times a token was unimportant (binary) versus raw un-normalized probabilities. If true, more memory efficient.",
    )

    group.add_argument(
        "--attn_record_freq",
        default=10,
        type=int,
        help="How often to record attention weights for the ScissorHands cache.",
    )
77+
78+
79+
def cache_compatibility(args):
    """Validate that the parsed cache arguments are mutually compatible.

    Args:
        args: Parsed argparse namespace containing the flags registered by
            ``add_cache_arguments``.

    Raises:
        AssertionError: If the combination of ``cache_strategy`` and the
            other cache arguments is known to be invalid.
    """
    if args.cache_strategy == "full":
        # Full implies no compression, which means --max_cache_length = [1.0] (same size as prompt + max_new_tokens)
        # Idiom fix: generator expression instead of all([...]) with the
        # ambiguous single-letter name "l" (flake8 E741 / C419).
        assert all(
            length == 1.0 for length in args.max_cache_length
        ), "Full cache strategy only supports max_cache_length=1.0."

    # Attention-based eviction policies must use an attention-based prompt compressor
    if args.cache_strategy in {"scissor"}:
        assert (
            args.prompt_compression_strategy == "snapkv"
        ), 'Scissor requires "snapkv" prompt compression strategy'

    print("The cache argument values you provided appear compatible with each other!")
693

794

895
class KVCache(ABC, nn.Module):
@@ -309,6 +396,7 @@ class KVCacheWindow(KVCache):
309396
"max_cache_length",
310397
"global_tokens",
311398
"prompt_compression_strategy",
399+
# NB: "recent_window" is ignored as a relevant kwarg. It is fixed to self.max_cache_length - self.global_tokens.
312400
]
313401

314402
def __init__(
@@ -467,7 +555,7 @@ def return_attn(self) -> bool:
467555
Whether or not we need to return attention weights for cache management.
468556
469557
We return attention weights if 3 conditions are met:
470-
1) The cache is not in the prefill stage
558+
1) The cache is not in the prefill stage.
471559
2) The number of tokens left in the eviction queue // the frequency with which we record attention < attention history window.
472560
3) The number of insertions is a multiple of the frequency with which we record attention.
473561

eval.py

Lines changed: 6 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import torch._inductor.config
1616

1717

18+
from cache import add_cache_arguments, cache_compatibility
19+
20+
1821
def device_sync(device):
1922
if "cuda" in device:
2023
torch.cuda.synchronize(device)
@@ -208,73 +211,11 @@ def main(
208211
"--device", type=str, default=default_device, help="Device to use"
209212
)
210213

211-
# KV-Cache Kwargs
212-
parser.add_argument(
213-
"--max_cache_length",
214-
type=float,
215-
default=[1.0],
216-
nargs="+",
217-
help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
218-
If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.",
219-
)
220-
parser.add_argument(
221-
"--cache_strategy",
222-
default="full",
223-
choices=["full", "random", "window", "scissor"],
224-
)
225-
# Optional Cache Kwargs depending on cache_strategy
226-
parser.add_argument(
227-
"--global_tokens",
228-
default=4,
229-
type=int,
230-
help="The number of initial tokens to always include in the KV-Cache. \
231-
If using window strategy, the actual window becomes max_cache_length - global_tokens.",
232-
)
233-
234-
# Scissorhands-specific Hyperparameters (--cache_strategy == "scissor")
235-
## See Algorithm 1 & 2 in arxiv.org/abs/2305.17118
236-
parser.add_argument(
237-
"--history_window_size", # Equivalent to "m" in Algorithm 2.
238-
default=400, # 400 is default specified in paper.
239-
type=int,
240-
help="The number of past tokens to consider when computing 'Heavy Hitters' in the KV-Cache.",
241-
)
242-
parser.add_argument(
243-
"--drop_amount", # Equivalent to "m" in Algorithm 2.
244-
default=0, # 0.4 is default specified in paper.
245-
type=float,
246-
help="The number of tokens to evict KV-Cache reaches capacity (max_cache_length). Expressed as a fraction of max_cache_length.",
247-
)
248-
parser.add_argument(
249-
"--recent_window", # Equivalent to "r" in Algorithm 2.
250-
default=10, # 10 is default specified in paper.
251-
type=int,
252-
help="The number of recently generated tokens to always save when evicting tokens from the ScissorHands KV-Cache.",
253-
)
254-
parser.add_argument(
255-
"-attn_thresholding",
256-
default=False,
257-
action="store_true",
258-
help="Whether to accumulate number of times a token was unimportant (binary) versus raw un-normalized probabilities. If true, less precise yet more space efficient.",
259-
)
214+
add_cache_arguments(parser)
260215

261216
args = parser.parse_args()
262217

263-
if args.cache_strategy == "full":
264-
# Full implies no compression, which means --max_cache_length = [1.0] (same size as prompt + max_new_tokens)
265-
assert all(
266-
[l == 1.0 for l in args.max_cache_length]
267-
), "Full cache strategy only supports max_cache_length=1.0."
268-
269-
cache_kwargs = {
270-
"cache_strategy": args.cache_strategy,
271-
"max_cache_length": args.max_cache_length,
272-
"global_tokens": args.global_tokens,
273-
"history_window_size": args.history_window_size,
274-
"drop_amount": args.drop_amount,
275-
"recent_window": args.recent_window,
276-
"attn_thresholding": args.attn_thresholding,
277-
}
218+
cache_compatibility(args)
278219

279220
main(
280221
args.tasks,
@@ -283,5 +224,5 @@ def main(
283224
args.checkpoint_path,
284225
args.profile,
285226
args.device,
286-
cache_kwargs,
227+
cache_kwargs=vars(args),
287228
)

generate.py

Lines changed: 5 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch._dynamo.config
1414
import torch._inductor.config
1515

16+
from cache import add_cache_arguments
1617
from generation_utils import decode_one_token, prefill
1718

1819

@@ -37,6 +38,7 @@ def device_sync(device):
3738

3839
from tokenizer import get_tokenizer
3940
from generation_utils import generate, encode_tokens, _load_model
41+
from cache import add_cache_arguments, cache_compatibility
4042

4143

4244
def _get_model_size(model):
@@ -306,70 +308,7 @@ def callback(x):
306308
"--device", type=str, default=default_device, help="Device to use"
307309
)
308310

309-
# KV-Cache Kwargs
310-
parser.add_argument(
311-
"--max_cache_length",
312-
type=float,
313-
default=[1.0],
314-
nargs="+",
315-
help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
316-
If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.",
317-
)
318-
parser.add_argument(
319-
"--cache_strategy",
320-
default="full",
321-
choices=["full", "random", "window", "scissor", "l2"],
322-
)
323-
324-
parser.add_argument(
325-
"--prompt_compression_strategy",
326-
default="recent_global",
327-
choices=["recent_global", "snapkv", "l2"],
328-
help="If |prompt| exceeds max_cache_length, we need to specify a strategy for compressing it to max_cache_length.",
329-
)
330-
331-
# Optional Cache Kwargs depending on cache_strategy
332-
parser.add_argument(
333-
"--global_tokens",
334-
default=4,
335-
type=int,
336-
help="The number of initial tokens to always include in the KV-Cache. \
337-
If using window strategy, the actual window becomes max_cache_length - global_tokens.",
338-
)
339-
340-
# Scissorhands-specific Hyperparameters (--cache_strategy == "scissor")
341-
## See Algorithm 1 & 2 in arxiv.org/abs/2305.17118
342-
parser.add_argument(
343-
"--history_window_size", # Equivalent to "m" in Algorithm 2.
344-
default=400, # 400 is default specified in paper.
345-
type=int,
346-
help="The number of past tokens to consider when computing 'Heavy Hitters' in the KV-Cache.",
347-
)
348-
parser.add_argument(
349-
"--drop_amount", # Equivalent to "m" in Algorithm 2.
350-
default=0.5, # 0.4 is default specified in paper.
351-
type=float,
352-
help="The number of tokens to evict KV-Cache reaches capacity (max_cache_length). Expressed as a fraction of max_cache_length.",
353-
)
354-
parser.add_argument(
355-
"--recent_window", # Equivalent to "r" in Algorithm 2.
356-
default=10, # 10 is default specified in paper.
357-
type=int,
358-
help="The number of recently generated tokens to always save when evicting tokens from the ScissorHands KV-Cache.",
359-
)
360-
parser.add_argument(
361-
"-attn_thresholding",
362-
default=False,
363-
action="store_true",
364-
help="Whether to accumulate number of times a token was unimportant (binary) versus raw un-normalized probabilities. If true, more memory efficient.",
365-
)
366-
367-
parser.add_argument(
368-
"--attn_record_freq",
369-
default=10,
370-
type=int,
371-
help="How often to record attention weights for the ScissorHands cache. Higher .",
372-
)
311+
add_cache_arguments(parser)
373312

374313
args = parser.parse_args()
375314

@@ -378,29 +317,7 @@ def callback(x):
378317
with open(prompt_fn) as fd:
379318
args.prompt = fd.read().strip()
380319

381-
if args.cache_strategy == "full":
382-
# Full implies no compression, which means --max_cache_length = [1.0] (same size as prompt + max_new_tokens)
383-
assert all(
384-
[l == 1.0 for l in args.max_cache_length]
385-
), "Full cache strategy only supports max_cache_length=1.0."
386-
387-
# Attention-based eviction policies must use an attention-based prompt compressor
388-
if args.cache_strategy in {"scissor"}:
389-
assert (
390-
args.prompt_compression_strategy == "snapkv"
391-
), 'Scissor requires "snapkv" prompt compression strategy'
392-
393-
cache_kwargs = {
394-
"cache_strategy": args.cache_strategy,
395-
"max_cache_length": args.max_cache_length,
396-
"global_tokens": args.global_tokens,
397-
"history_window_size": args.history_window_size,
398-
"drop_amount": args.drop_amount,
399-
"recent_window": args.recent_window,
400-
"attn_thresholding": args.attn_thresholding,
401-
"prompt_compression_strategy": args.prompt_compression_strategy,
402-
"attn_record_freq": args.attn_record_freq,
403-
}
320+
cache_compatibility(args)
404321

405322
main(
406323
args.prompt,
@@ -416,5 +333,5 @@ def callback(x):
416333
args.draft_checkpoint_path,
417334
args.speculate_k,
418335
args.device,
419-
cache_kwargs,
336+
cache_kwargs=vars(args),
420337
)

0 commit comments

Comments
 (0)