Skip to content

Commit 11db56e

Browse files
author
Griffin Adams
committed
Update generate.py to pull from generation_utils.py
1 parent e3e1b25 commit 11db56e

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import sys
88
import time
99
from pathlib import Path
10-
from typing import Optional, Tuple
10+
from typing import Optional
1111

1212
import torch
1313
import torch._dynamo.config
1414
import torch._inductor.config
1515

16+
from generation_utils import decode_one_token, prefill
17+
1618

1719
def device_sync(device):
1820
if "cuda" in device:

prompt_compression.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ def __init__(self, head_specific, **kwargs) -> None:
6868
self.kernel_size = 5
6969
self.observation_len = 16
7070

71+
self.pool = torch.nn.AvgPool1d(
72+
self.kernel_size,
73+
stride=1,
74+
padding=self.kernel_size // 2,
75+
ceil_mode=False,
76+
count_include_pad=False,
77+
)
78+
7179
def is_compatible(self) -> bool:
7280
# Can only be used with head-specific KV-caches
7381
return self.head_specific
@@ -78,19 +86,12 @@ def requires_attn(self) -> bool:
7886
def __call__(self, input_pos, k_val, v_val, attn):
7987
assert self.head_specific, "SnapKV can only be used with head-specific KV-caches (e.g., placing the same token in different locations across heads)."
8088

81-
pool = torch.nn.AvgPool1d(
82-
self.kernel_size,
83-
stride=1,
84-
padding=self.kernel_size // 2,
85-
ceil_mode=False,
86-
count_include_pad=False,
87-
)
8889
priority = attn[:, :, -self.observation_len :, :].mean(dim=2)
8990
prev_shape = priority.shape
9091

9192
# We'll be returning the attention history so we need to keep a copy before it's modified
9293
attn_history = priority.clone()
93-
priority = pool(priority)
94+
priority = self.pool(priority)
9495
assert (
9596
priority.shape == prev_shape
9697
), f"Pooling operation should not change the dimension: {prev_shape} -> {priority.shape}"

0 commit comments

Comments (0)