Skip to content

Commit e8483dd

Browse files
authored
Merge pull request #7311 from ranqiu92/plot
Add script to plot learning curve
2 parents 6ecbf08 + 893a15f commit e8483dd

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

benchmark/paddle/image/plotlog.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
#     http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import argparse
17+
import matplotlib.pyplot as plt
18+
19+
20+
def parse_args():
    """Parse the command-line options of the log plotting script.

    Recognised options:
        --file_path / -f    path of the log file (str, no default).
        --sample_rate / -s  rate at which samples are taken (float, 1.0).
        --log_period / -p   period of the log (int, 1).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(prog='Parse Log')
    cli.add_argument(
        '--file_path',
        '-f',
        type=str,
        help='the path of the log file')
    cli.add_argument(
        '--sample_rate',
        '-s',
        type=float,
        default=1.0,
        help='the rate to take samples from log')
    cli.add_argument(
        '--log_period',
        '-p',
        type=int,
        default=1,
        help='the period of log')
    return cli.parse_args()
35+
36+
37+
def parse_file(file_name):
    """Read the loss and accuracy series out of a training log.

    A line is used only if, after stripping surrounding whitespace, it
    starts with ``pass`` and splits on single spaces into exactly five
    fields. The third field supplies the loss and the fourth the error;
    for each, the last character is dropped and the text after the final
    ``=`` is converted to ``float``.

    Args:
        file_name: path of the log file to read.

    Returns:
        A ``(loss, accuracy)`` tuple of lists, where every accuracy entry
        is ``1.0 - error`` for the corresponding log line.
    """

    def field_value(field):
        # Drop the trailing character, then keep what follows the last '='.
        return float(field[:-1].rpartition('=')[2])

    losses = []
    accuracies = []
    with open(file_name) as log:
        for raw in log:
            text = raw.strip()
            fields = text.split(' ')
            if text.startswith('pass') and len(fields) == 5:
                losses.append(field_value(fields[2]))
                accuracies.append(1.0 - field_value(fields[3]))

    return losses, accuracies
60+
61+
62+
def sample(metric, sample_rate):
    """Down-sample a sequence by keeping one element per interval.

    The interval is ``int(1.0 / sample_rate)``. Elements at indices
    ``0, interval, 2 * interval, ...`` are kept, and a trailing partial
    interval is dropped, so series of equal length sampled at the same
    rate stay aligned.

    Args:
        metric: sequence of values to sample.
        sample_rate: fraction of elements to keep, in the range (0, 1].

    Returns:
        The sampled elements. If the interval exceeds ``len(metric)``,
        only the first element is returned (nothing for empty input).

    Raises:
        ValueError: if ``sample_rate`` is not in the range (0, 1].
    """
    if not 0.0 < sample_rate <= 1.0:
        raise ValueError(
            'sample_rate must be in the range (0, 1], got {!r}'.format(
                sample_rate))

    interval = int(1.0 / sample_rate)
    if interval > len(metric):
        return metric[:1]

    # Floor division keeps the count an int; true division yields a float
    # on Python 3, which range() rejects.
    kept = len(metric) // interval
    return [metric[interval * i] for i in range(kept)]
71+
72+
73+
def plot_metric(metric,
                batch_id,
                graph_title,
                line_style='b-',
                line_label='y',
                line_num=1):
    """Draw one or more curves against ``batch_id`` and save the figure.

    Args:
        metric: y-values of the curve when ``line_num`` is 1; otherwise an
            indexable collection holding one series per curve.
        batch_id: x-values shared by every curve.
        graph_title: used as the figure title, the y-axis label and the
            stem of the output file name.
        line_style: format string for the curve, or an indexable
            collection of them when ``line_num`` is not 1.
        line_label: legend label, or an indexable collection of labels
            when ``line_num`` is not 1.
        line_num: number of curves to draw.

    The figure is written to ``graph_title + '.jpg'`` and then closed.
    """
    if line_num == 1:
        curves = [(metric, line_style, line_label)]
    else:
        # Lazy on purpose: each series is indexed right before it is drawn.
        curves = ((metric[k], line_style[k], line_label[k])
                  for k in range(line_num))

    plt.figure()
    plt.title(graph_title)
    for values, style, label in curves:
        plt.plot(batch_id, values, style, label=label)
    plt.xlabel('batch')
    plt.ylabel(graph_title)
    plt.legend()
    plt.savefig(graph_title + '.jpg')
    plt.close()
91+
92+
93+
def main():
    """Parse a training log and save its loss and accuracy curves.

    Reads the command-line options, extracts the loss/accuracy series
    from the log file, down-samples them together with the matching batch
    indices, and writes ``loss.jpg`` and ``accuracy.jpg``.

    Raises:
        ValueError: if the sample rate is outside (0, 1] or no log file
            path was given.
    """
    args = parse_args()
    # Explicit checks instead of ``assert``: assertions are removed when
    # Python runs with -O, which would let invalid input through.
    if not 0. < args.sample_rate <= 1.0:
        raise ValueError("The sample rate should in the range (0, 1].")
    if not args.file_path:
        raise ValueError(
            "The path of the log file must be given via --file_path/-f.")

    loss, accuracy = parse_file(args.file_path)
    # Entry i of the log corresponds to batch number log_period * i.
    batch = [args.log_period * i for i in range(len(loss))]

    batch_sample = sample(batch, args.sample_rate)
    loss_sample = sample(loss, args.sample_rate)
    accuracy_sample = sample(accuracy, args.sample_rate)

    plot_metric(loss_sample, batch_sample, 'loss', line_label='loss')
    plot_metric(
        accuracy_sample,
        batch_sample,
        'accuracy',
        line_style='g-',
        line_label='accuracy')
111+
112+
113+
# Run the plotting pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)