|
1 | 1 | import pandas as pd |
2 | 2 | import plotly.express as px |
| 3 | +import numpy as np |
| 4 | +from scipy.interpolate import make_interp_spline |
3 | 5 |
|
def _percent_tick_labels(values):
    """Format fractional values (e.g. 0.15) as whole-percent labels ('15%')."""
    # round() instead of int(): 0.29 * 100 == 28.999999999999996, which
    # int() would truncate to 28.
    return [f'{int(round(float(val) * 100))}%' for val in values]


def _smoothed_mean_curve(df, x, y, num_points=500):
    """Return (x_smooth, y_smooth) for a spline through the per-x mean of y.

    Returns None when there are fewer than two distinct x values, because no
    interpolating curve can be built from a single point.
    """
    # groupby sorts by key, so the x values handed to the spline are unique
    # and increasing.
    avg = df.groupby(x)[y].mean().reset_index()
    if len(avg) < 2:
        return None

    # A degree-k interpolating spline needs at least k + 1 points; fall back
    # to a lower degree rather than raising on small inputs.
    degree = min(3, len(avg) - 1)
    x_smooth = np.linspace(avg[x].min(), avg[x].max(), num_points)
    spline = make_interp_spline(avg[x], avg[y], k=degree)
    return x_smooth, spline(x_smooth)


def plot_metric_trend(df, x, y, color='max length'):
    """Plot y against x as one line per group, plus a smoothed average curve.

    Args:
        df: DataFrame containing the columns named by ``x``, ``y`` and
            ``color``.
        x: Name of the column holding fractional x values; ticks are
            labelled as percentages.
        y: Name of the metric column.
        color: Name of the column used to split the data into separate
            lines. Defaults to 'max length'.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    fig = px.line(
        df, x=x, y=y, color=color, markers=True,
        # Spare colour left in the original trailing comment: A5AEB7
        color_discrete_sequence=['#925EB0', '#7E99F4', '#CC7C71', '#7AB656'],
    )

    # Tick positions and labels must come from the same de-duplicated values
    # so that they stay aligned one-to-one.
    tick_positions = df[x].unique()
    fig.update_xaxes(tickvals=list(tick_positions),
                     ticktext=_percent_tick_labels(tick_positions))

    # Add the smooth average line to the plot
    curve = _smoothed_mean_curve(df, x, y)
    if curve is not None:
        x_smooth, y_smooth = curve
        fig.add_scatter(x=x_smooth, y=y_smooth, mode='lines',
                        name='Smooth Average',
                        line={
                            'color': 'red',
                            'width': 4,
                            'dash': 'solid',
                        })

    # Both axes share the same styling: light-gray grid and a mirrored
    # light-gray axis line.
    axis_style = {
        'showgrid': True,
        'gridcolor': 'lightgray',
        'linewidth': 1,
        'linecolor': 'lightgray',
        'showline': True,
        'mirror': True,
    }
    fig.update_layout(xaxis_title='Percentage',
                      yaxis_title='Agr Average',
                      plot_bgcolor='rgba(0,0,0,0)',   # transparent plot area
                      paper_bgcolor='rgba(0,0,0,0)',  # transparent figure background
                      xaxis=dict(axis_style),
                      yaxis=dict(axis_style))

    fig.show()
32 | 49 |
|
if __name__ == "__main__":
    # Every run below is sampled at the same percentage points.
    percentages = [0, 0.05, 0.15, 0.3, 0.4, 0.6, 0.8, 1]

    # Metric values per run, keyed by the run's label.
    scores_by_run = {
        "max length 512": [0, 0.79, 2.078, 1.755, 1.247, 3.447, 2.967, 2.175],
        "max length 1024": [0, 1.266, 1.399, 1.723, 1.247, 2.581, 2.581, 1.291],  # TODO: 0.4, 0.6
        "max length 1536": [0, 0.983, 1.622, 2.563, 1.453, 1.326, 1.225, 1.885],
        "max length 2048": [0, 0.365, 0.967, 2.231, 1.256, 1.616, 2.052, 1.529],
    }

    # One long-format frame per run, each tagged with its label, then stacked
    # with a fresh index.
    frames = [
        pd.DataFrame({"percentage": percentages, "agr average": scores})
        .assign(**{"max length": run_label})
        for run_label, scores in scores_by_run.items()
    ]
    combined = pd.concat(frames, ignore_index=True)

    plot_metric_trend(combined, 'percentage', 'agr average')
|
0 commit comments