Skip to content

Commit 127773a

Browse files
author
Griffin Adams
committed
Add random prompt compression strategy.
1 parent c608e80 commit 127773a

File tree

4 files changed

+42
-7
lines changed

4 files changed

+42
-7
lines changed

cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
3535
group.add_argument(
3636
"--prompt_compression_strategy", # This doesn't matter if args.feed_long_prompts is True
3737
default="recent_global",
38-
choices=["recent_global", "snapkv", "l2"],
38+
choices=["recent_global", "snapkv", "l2", "random"],
3939
help="If |prompt| exceeds max_cache_length, we need to specify a strategy for compressing it to max_cache_length.",
4040
)
4141

cache_configs/random.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
cache_strategy: "random"
22
max_cache_length: [1024]
33
feed_long_prompt: True
4-
global_tokens: 1
4+
global_tokens: 4

cache_configs/scissor.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ cache_strategy: "scissor"
22
max_cache_length: [1024]
33
global_tokens: 4
44
history_window_size: 400
5-
drop_amount: 0.0001
5+
drop_amount: 0
66
recent_window: 10
77
attn_thresholding: False
88
prompt_compression_strategy: "snapkv"

prompt_compression.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,37 @@ def is_compatible(self) -> bool:
2424
pass
2525

2626

27+
class PromptCompressorRandom(PromptCompressor):
    """Prompt compressor that keeps the leading global tokens plus a random sample.

    The first ``self.global_tokens`` positions are always retained; the rest of
    the ``self.max_cache_length`` budget is filled with positions drawn
    uniformly at random (without replacement) from the remainder of the prompt.

    NOTE(review): ``global_tokens`` and ``max_cache_length`` are presumably set
    on the instance by ``PromptCompressor.__init__`` from ``kwargs`` -- confirm
    against the base class.
    """

    def __init__(self, head_specific, **kwargs) -> None:
        super().__init__(head_specific, **kwargs)

    def is_compatible(self) -> bool:
        # Can be used with any cache
        return True

    def requires_attn(self) -> bool:
        # Selection is random, so no attention scores are consulted.
        return False

    def __call__(self, input_pos, k_val, v_val):
        """Select ``max_cache_length`` prompt positions and gather their keys/values.

        Args:
            input_pos: Tensor whose first dimension is the prompt length; only
                its ``shape[0]`` and ``device`` are read here.
            k_val: Key tensor, indexed along dimension 2 by the kept positions
                (assumes the sequence axis is dim 2 -- TODO confirm with caller).
            v_val: Value tensor, indexed the same way as ``k_val``.

        Returns:
            A ``(keep_idxs, k_val, v_val)`` tuple where ``keep_idxs`` holds the
            global indices followed by the sorted random indices, and the
            key/value tensors are gathered at those indices.

        Raises:
            ValueError: If the configuration cannot yield exactly
                ``max_cache_length`` distinct positions, i.e. unless
                ``0 <= global_tokens <= max_cache_length <= prompt length``.
        """
        seq_len = input_pos.shape[0]
        num_global = self.global_tokens
        budget = self.max_cache_length

        # Validate explicitly rather than with a trailing ``assert``: asserts are
        # stripped under ``python -O``, which would let a short prompt or an
        # oversized ``global_tokens`` silently produce a wrongly sized cache.
        if not 0 <= num_global <= budget <= seq_len:
            raise ValueError(
                "Random prompt compression requires "
                "0 <= global_tokens <= max_cache_length <= prompt length, got "
                f"global_tokens={num_global}, max_cache_length={budget}, "
                f"prompt length={seq_len}"
            )

        device = input_pos.device
        global_idxs = torch.arange(num_global, device=device)

        # Shuffle the non-global positions, keep as many as the remaining budget
        # allows, then shift them past the global prefix and restore order.
        shuffled = torch.randperm(seq_len - num_global, device=device)
        rand_idxs = (num_global + shuffled[: budget - num_global]).sort().values

        keep_idxs = torch.cat([global_idxs, rand_idxs], dim=0)
        return keep_idxs, k_val[:, :, keep_idxs], v_val[:, :, keep_idxs]
56+
57+
2758
class PromptCompressorRecentGlobal(PromptCompressor):
2859
def __init__(self, head_specific, **kwargs) -> None:
2960
super().__init__(head_specific, **kwargs)
@@ -84,9 +115,10 @@ def requires_attn(self) -> bool:
84115
return True
85116

86117
def __call__(self, input_pos, k_val, v_val, attn):
87-
assert self.head_specific, "SnapKV can only be used with head-specific KV-caches, e.g., placing the same token in different locations across heads)."
118+
seq_len = input_pos.shape[0]
119+
obs_len = min(self.observation_len, seq_len)
88120

89-
priority = attn[:, :, -self.observation_len :, :].mean(dim=2)
121+
priority = attn[:, :, -obs_len:, :].mean(dim=2)
90122
prev_shape = priority.shape
91123

92124
# We'll be returning the attention history so we need to keep a copy before it's modified
@@ -95,8 +127,9 @@ def __call__(self, input_pos, k_val, v_val, attn):
95127
assert (
96128
priority.shape == prev_shape
97129
), f"Pooling operation should not change the dimension: {prev_shape} -> {priority.shape}"
98-
priority[:, :, -self.observation_len :] = (
99-
1.0 # Ensure the observation window is selected
130+
priority[:, :, -obs_len:] = 1.0 # Ensure the observation window is selected
131+
priority[:, :, : self.global_tokens] = (
132+
1.0 # Ensure the global tokens are selected
100133
)
101134
keep_idxs = (
102135
priority.topk(self.max_cache_length, dim=-1).indices.sort(dim=-1).values
@@ -152,5 +185,7 @@ def prompt_compressor_constructor(strategy):
152185
return PromptCompressorSnapKV
153186
elif strategy == "l2":
154187
return PromptCompressorL2
188+
elif strategy == "random":
189+
return PromptCompressorRandom
155190
else:
156191
raise ValueError(f"Unknown prompt compression strategy: {strategy}")

0 commit comments

Comments
 (0)