Skip to content

Commit 2076426

Browse files
author
ranqiu
committed
Add script to plot learning curve
1 parent e8a96a8 commit 2076426

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

benchmark/paddle/image/plotlog.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#coding=utf-8
2+
3+
import sys
4+
import argparse
5+
import matplotlib.pyplot as plt
6+
7+
8+
def parse_args():
9+
parser = argparse.ArgumentParser('Parse Log')
10+
parser.add_argument(
11+
'--file_path', '-f', type=str, help='the path of the log file')
12+
parser.add_argument(
13+
'--sample_rate',
14+
'-s',
15+
type=float,
16+
default=1.0,
17+
help='the rate to take samples from log')
18+
parser.add_argument(
19+
'--log_period', '-p', type=int, default=1, help='the period of log')
20+
21+
args = parser.parse_args()
22+
return args
23+
24+
25+
def parse_file(file_name):
26+
loss = []
27+
error = []
28+
with open(file_name) as f:
29+
for i, line in enumerate(f):
30+
line = line.strip()
31+
if not line.startswith('pass'):
32+
continue
33+
line_split = line.split(' ')
34+
if len(line_split) != 5:
35+
continue
36+
37+
loss_str = line_split[2][:-1]
38+
cur_loss = float(loss_str.split('=')[-1])
39+
loss.append(cur_loss)
40+
41+
err_str = line_split[3][:-1]
42+
cur_err = float(err_str.split('=')[-1])
43+
error.append(cur_err)
44+
45+
accuracy = [1.0 - err for err in error]
46+
47+
return loss, accuracy
48+
49+
50+
def sample(metric, sample_rate):
51+
interval = int(1.0 / sample_rate)
52+
if interval > len(metric):
53+
return metric[:1]
54+
55+
num = len(metric) / interval
56+
idx = [interval * i for i in range(num)]
57+
metric_sample = [metric[id] for id in idx]
58+
return metric_sample
59+
60+
61+
def plot_metric(metric, batch_id, graph_title):
62+
plt.figure()
63+
plt.title(graph_title)
64+
plt.plot(batch_id, metric)
65+
plt.xlabel('batch')
66+
plt.ylabel(graph_title)
67+
plt.savefig(graph_title + '.jpg')
68+
plt.close()
69+
70+
71+
def main():
72+
args = parse_args()
73+
assert args.sample_rate > 0. and args.sample_rate <= 1.0, "The sample rate should in the range (0, 1]."
74+
75+
loss, accuracy = parse_file(args.file_path)
76+
batch = [args.log_period * i for i in range(len(loss))]
77+
78+
batch_sample = sample(batch, args.sample_rate)
79+
loss_sample = sample(loss, args.sample_rate)
80+
accuracy_sample = sample(accuracy, args.sample_rate)
81+
82+
plot_metric(loss_sample, batch_sample, 'loss')
83+
plot_metric(accuracy_sample, batch_sample, 'accuracy')
84+
85+
86+
if __name__ == '__main__':
87+
main()

0 commit comments

Comments
 (0)