import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


if __name__ == "__main__":
    # Load the per-step attention-loss and perplexity-delta measurements
    df = pd.read_csv("/workspace/attention_loss.csv")
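    # Assumed layout of attention_loss.csv (inferred from the column accesses below):
    # at least 16 rows, with one attention-loss and one PPL-delta column per
    # compression level: "{25,50,75}_attention_loss" and "{25,50,75}_ppl_delta".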

    # 16 decoding-step checkpoints: 500, 1000, ..., 8000
    decoding_steps = np.arange(500, 8500, 500)
    n = len(decoding_steps)
    models = ['Low Compression', 'Medium Compression', 'High Compression']

    # Attention loss per compression level, read from the CSV columns
    attention_loss = {
        'Low Compression': df["25_attention_loss"][:n],
        'Medium Compression': df["50_attention_loss"][:n],
        'High Compression': df["75_attention_loss"][:n],
    }

    # Perplexity delta (PPL Δ) per compression level
    ppl_delta = {
        'Low Compression': df["25_ppl_delta"][:n],
        'Medium Compression': df["50_ppl_delta"][:n],
        'High Compression': df["75_ppl_delta"][:n],
    }

    # Create the plot
    plt.rcParams.update({'font.size': 20})

    fig, ax1 = plt.subplots(figsize=(20, 10))

    # Colors for each model
    colors = ["#006AA7", '#16a085', '#8e44ad', '#d35400']

    # Plot attention loss on the left y-axis (solid lines, round markers)
    for model, color in zip(models, colors):
        ax1.plot(decoding_steps, attention_loss[model], color=color, label=f'{model} (Attention Loss)', linewidth=6)
        ax1.scatter(decoding_steps, attention_loss[model], color=color, s=400)

    ax1.set_xlabel('Decoding Steps', fontsize=32)
    ax1.set_ylabel('Attention Loss', fontsize=32)

    ax1.tick_params(axis='y', labelsize=32)
    ax1.tick_params(axis='x', labelsize=32)

    # Create a second y-axis for the perplexity delta
    ax2 = ax1.twinx()

    # Plot perplexity delta on the right y-axis (dashed lines, square markers)
    for model, color in zip(models, colors):
        ax2.plot(decoding_steps, ppl_delta[model], color=color, linestyle='--', label=f'{model} (PPL Δ)', linewidth=6)
        ax2.scatter(decoding_steps, ppl_delta[model], color=color, marker='s', s=400)

    ax2.set_ylabel("Perplexity Delta (PPL Δ)", fontsize=32)
    ax2.tick_params(axis="y", labelsize=32)

    # Combine legends
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left', bbox_to_anchor=(0.05, 0.95), borderaxespad=0.25, fontsize=24)

    plt.title("Attention Loss & Perplexity vs Decoding Steps", fontsize=32)
    ax1.grid(True)  # draw grid lines aligned to the primary (attention-loss) axis
    plt.tight_layout()
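    # Create the output directory if it does not exist yet
    # (assumes the /workspace/cold-compress path used by savefig below is writable)
    os.makedirs("/workspace/cold-compress/charts", exist_ok=True)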
    plt.savefig("/workspace/cold-compress/charts/attention_loss.png")