Commit a10b9a9

Author: Griffin Adams

Add attention loss to README.

1 parent 9bbca97

File tree

4 files changed: +74 −1 lines

README.md

Lines changed: 7 additions & 1 deletion
@@ -249,7 +249,7 @@ To better understand why one method may work better than another, it is importan
Specifically, it’s nice to be able to understand the deviation from full attention caused by token dropping. As defined in [L2-Norm](https://arxiv.org/abs/2406.11430) and [FastGen](https://arxiv.org/abs/2310.01801), we compute the attention loss as the sum of the attention probabilities for the evicted tokens.

-![Attention Loss Diagram](images/AttentionLoss.png)
+![Attention Loss Diagram](images/attention_loss_concept.png)

To calculate the **Attention Loss**, we need to keep all tokens in the KVCache, e.g., set the cache strategy to `full`, while simulating evictions for a compressed cache.
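To make the definition above concrete, here is a minimal sketch of the per-step computation, assuming you already have the attention probabilities from the full (uncompressed) cache and a boolean mask marking which positions a compressed cache would have evicted. The array names, shapes, and the mean over heads are illustrative choices, not the repo's implementation:

```python
import numpy as np

def attention_loss(attn_probs: np.ndarray, evicted_mask: np.ndarray) -> float:
    """Attention mass that the full model places on evicted tokens.

    attn_probs:   [num_heads, seq_len] attention probabilities for the current
                  query position, computed with the full (uncompressed) KV cache.
    evicted_mask: [seq_len] boolean array, True for tokens a compressed cache
                  would have evicted.
    """
    # Sum the probability assigned to evicted positions, then average over heads
    # (how to aggregate across heads/layers is a design choice, not fixed here).
    return float(attn_probs[:, evicted_mask].sum(axis=-1).mean())

# Toy usage with random numbers (not real model outputs):
rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 128))                               # 8 heads, 128 cached tokens
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
evicted = np.zeros(128, dtype=bool)
evicted[:64] = True                                              # pretend the oldest 64 tokens were evicted
print(attention_loss(probs, evicted))
```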

@@ -265,6 +265,12 @@ A handful of debugging experiments can be kicked off by running:
bash experiments/attention_loss.sh
```

+These experiments record Attention Loss at various decoding steps. Because they also record perplexity (PPL) on the [PG-19 Book Corpus](https://github.com/google-deepmind/pg19), we can show a clear correlation between Attention Loss and downstream performance (PPL).
+
+![Attention Loss Results](images/attention_loss_pg19.png)
+
+This suggests that **Attention Loss** might be a decent proxy for approximating downstream degradation from compression.
+
## Extending Cold Compress

### Adding a new Cache Strategy
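As a rough check of the correlation claimed in the README additions above, the two quantities can be correlated directly from the metrics CSV that the plotting script below reads. The CSV path and column names are taken from `charts/attention_loss.py`; this is an illustrative sketch, not part of the repo's experiment pipeline:

```python
import pandas as pd

# Correlate Attention Loss with the perplexity degradation it is meant to predict.
df = pd.read_csv("/workspace/attention_loss.csv")

# Column prefixes follow charts/attention_loss.py (low / medium / high compression).
for level in ("25", "50", "75"):
    r = df[f"{level}_attention_loss"].corr(df[f"{level}_ppl_delta"])  # Pearson correlation
    print(f"{level}: Pearson r = {r:.3f}")
```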

charts/attention_loss.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


if __name__ == "__main__":
    # Load per-step attention loss and perplexity-delta metrics
    df = pd.read_csv("/workspace/attention_loss.csv")

    decoding_steps = np.arange(500, 8500, 500)
    n = len(decoding_steps)
    models = ['Low Compression', 'Medium Compression', 'High Compression']

    # Attention loss at each decoding step for each compression level
    attention_loss = {
        'Low Compression': df["25_attention_loss"][:n],
        'Medium Compression': df["50_attention_loss"][:n],
        'High Compression': df["75_attention_loss"][:n],
    }

    # Perplexity delta (compressed vs. full cache) at each decoding step
    ppl_delta = {
        'Low Compression': df["25_ppl_delta"][:n],
        'Medium Compression': df["50_ppl_delta"][:n],
        'High Compression': df["75_ppl_delta"][:n],
    }

    # Create the plot
    plt.rcParams.update({'font.size': 20})

    fig, ax1 = plt.subplots(figsize=(20, 10))

    # Colors for each compression level
    colors = ["#006AA7", '#16a085', '#8e44ad', '#d35400']

    # Plot Attention Loss (solid lines, left y-axis)
    for model, color in zip(models, colors):
        ax1.plot(decoding_steps, attention_loss[model], color=color, label=f'{model} (Attention Loss)', linewidth=6)
        ax1.scatter(decoding_steps, attention_loss[model], color=color, s=400)

    ax1.set_xlabel('Decoding Steps', fontsize=32)
    ax1.set_ylabel('Attention Loss', fontsize=32)

    ax1.tick_params(axis='y', labelsize=32)
    ax1.tick_params(axis='x', labelsize=32)

    # Create a second y-axis for the perplexity delta
    ax2 = ax1.twinx()

    # Plot Perplexity Delta (dashed lines, right y-axis)
    for model, color in zip(models, colors):
        ax2.plot(decoding_steps, ppl_delta[model], color=color, linestyle='--', label=f'{model} (PPL Δ)', linewidth=6)
        ax2.scatter(decoding_steps, ppl_delta[model], color=color, marker='s', s=400)

    ax2.set_ylabel("Perplexity Delta (PPL Δ)", fontsize=32)
    ax2.tick_params(axis="y", labelsize=32)

    # Combine the legends from both axes
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left', bbox_to_anchor=(0.05, 0.95), borderaxespad=0.25, fontsize=24)

    plt.title("Attention Loss & Perplexity vs Decoding Steps", fontsize=32)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("/workspace/cold-compress/charts/attention_loss.png")
```
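To smoke-test the plotting script without running the full experiments, a placeholder CSV with the expected columns is enough. The column names mirror the script above and the values are random, purely for checking that the chart renders:

```python
import numpy as np
import pandas as pd

# Write a dummy metrics file with the columns charts/attention_loss.py expects.
rng = np.random.default_rng(0)
steps = np.arange(500, 8500, 500)  # the script plots 16 decoding steps
df = pd.DataFrame({
    f"{level}_{metric}": rng.random(len(steps))
    for level in ("25", "50", "75")
    for metric in ("attention_loss", "ppl_delta")
})
df.to_csv("/workspace/attention_loss.csv", index=False)
```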
images/attention_loss_concept.png

File renamed without changes (previously images/AttentionLoss.png).

images/attention_loss_pg19.png

217 KB
