Skip to content

Commit f845010

Browse files
feat(charts): add plot_metric_trend method
1 parent ba4fdb3 commit f845010

File tree

1 file changed

+50
-29
lines changed

1 file changed

+50
-29
lines changed

charts/plot_metric_trend.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,79 @@
11
import pandas as pd
22
import plotly.express as px
3+
import numpy as np
4+
from scipy.interpolate import make_interp_spline
35

46
def plot_metric_trend(df, x, y):
5-
fig = px.line(df, x=x, y=y, color='epoch', markers=True)
7+
8+
fig = px.line(df, x=x, y=y, color='max length', markers=True, color_discrete_sequence=['#925EB0', '#7E99F4', '#CC7C71', '#7AB656']) # A5AEB7
69

710
fig.update_xaxes(tickvals=df[x], ticktext=[f'{int(val * 100)}%' for val in df[x].unique()])
811

12+
avg = df.groupby(x)[y].mean().reset_index()
13+
avg['max length'] = 'Average'
14+
15+
x_smooth = np.linspace(avg[x].min(), avg[x].max(), 500)
16+
spline = make_interp_spline(avg[x], avg[y], k=3) # k=3 for cubic spline
17+
y_smooth = spline(x_smooth)
18+
19+
# Add the smooth average line to the plot
20+
fig.add_scatter(x=x_smooth, y=y_smooth, mode='lines', name='Smooth Average',
21+
line={
22+
'color': 'red',
23+
'width': 4,
24+
'dash': 'solid'
25+
})
26+
927
fig.update_layout(xaxis_title='Percentage',
10-
yaxis_title='Agr Average Improvement',
11-
plot_bgcolor='rgba(0,0,0,0)', # 设置绘图区域背景为透明
12-
paper_bgcolor='rgba(0,0,0,0)', # 设置整个图表背景为透明
28+
yaxis_title='Agr Average',
29+
plot_bgcolor='rgba(0,0,0,0)',
30+
paper_bgcolor='rgba(0,0,0,0)',
1331
xaxis={
1432
'showgrid': True,
1533
'gridcolor': 'lightgray',
1634
'linewidth': 1,
1735
'linecolor': 'lightgray',
18-
'showline': True, # 显示x轴线
19-
'mirror': True, # 在对面也显示轴线
36+
'showline': True,
37+
'mirror': True,
2038
},
2139
yaxis={
2240
'showgrid': True,
2341
'gridcolor': 'lightgray',
2442
'linewidth': 1,
2543
'linecolor': 'lightgray',
26-
'showline': True, # 显示y轴线
27-
'mirror': True, # 在对面也显示轴线
44+
'showline': True,
45+
'mirror': True,
2846
})
2947

3048
fig.show()
31-
# TODO: 添加baseline的线
3249

3350
if __name__ == "__main__":
34-
epoch_1 = {
35-
"percentage": [0.05, 0.15, 0.3, 0.4, 0.6, 1],
36-
"agr average": [2.16, 0.48, 1.62, 1.77, 2.99, 2.19]
51+
data = {
52+
"max length 512": {
53+
"percentage": [0, 0.05, 0.15, 0.3, 0.4, 0.6, 0.8, 1],
54+
"agr average": [0, 0.79, 2.078, 1.755, 1.247, 3.447, 2.967, 2.175]
55+
},
56+
"max length 1024": {
57+
"percentage": [0, 0.05, 0.15, 0.3, 0.4, 0.6, 0.8, 1],
58+
"agr average": [0, 1.266, 1.399, 1.723, 1.247, 2.581, 2.581, 1.291] # TODO: 0.4, 0.6
59+
},
60+
"max length 1536": {
61+
"percentage": [0, 0.05, 0.15, 0.3, 0.4, 0.6, 0.8, 1],
62+
"agr average": [0, 0.983, 1.622, 2.563, 1.453, 1.326, 1.225, 1.885]
63+
},
64+
"max length 2048": {
65+
"percentage": [0, 0.05, 0.15, 0.3, 0.4, 0.6, 0.8, 1],
66+
"agr average": [0, 0.365, 0.967, 2.231, 1.256, 1.616, 2.052, 1.529]
67+
}
3768
}
38-
epoch_2 = {
39-
"percentage": [0.05, 0.15, 0.3, 0.4, 0.6, 1],
40-
"agr average": [1.78, 0.52, 2.23, 2.21, 1.62, 1.53]
41-
}
42-
epoch_3 = {
43-
"percentage": [0.05, 0.15, 0.3, 0.4, 0.6, 1],
44-
"agr average": [1.39, 0.21, 1.60, 1.76, 1.31, 1.80]
45-
}
46-
47-
df_epoch_1 = pd.DataFrame(epoch_1)
48-
df_epoch_1['epoch'] = 'Epoch 1'
4969

50-
df_epoch_2 = pd.DataFrame(epoch_2)
51-
df_epoch_2['epoch'] = 'Epoch 2'
70+
df_list = []
71+
for length, values in data.items():
72+
df_temp = pd.DataFrame(values)
73+
df_temp['max length'] = length
74+
df_list.append(df_temp)
5275

53-
df_epoch_3 = pd.DataFrame(epoch_3)
54-
df_epoch_3['epoch'] = 'Epoch 3'
76+
df = pd.concat(df_list, ignore_index=True)
5577

56-
df_combined = pd.concat([df_epoch_1, df_epoch_2, df_epoch_3])
78+
plot_metric_trend(df, 'percentage', 'agr average')
5779

58-
plot_metric_trend(df_combined, x='percentage', y='agr average')

0 commit comments

Comments
 (0)