@@ -70,12 +70,19 @@ def sample(metric, sample_rate):
70
70
return metric_sample
71
71
72
72
73
- def plot_metric (metric , batch_id , graph_title ):
73
+ def plot_metric (metric , batch_id , graph_title , line_style = 'b-' ,
74
+ line_label = 'y' ,
75
+ line_num = 1 ):
74
76
plt .figure ()
75
77
plt .title (graph_title )
76
- plt .plot (batch_id , metric )
78
+ if line_num == 1 :
79
+ plt .plot (batch_id , metric , line_style , line_label )
80
+ else :
81
+ for i in line_num :
82
+ plt .plot (batch_id , metric [i ], line_style [i ], line_label [i ])
77
83
plt .xlabel ('batch' )
78
84
plt .ylabel (graph_title )
85
+ plt .legend ()
79
86
plt .savefig (graph_title + '.jpg' )
80
87
plt .close ()
81
88
@@ -91,8 +98,8 @@ def main():
91
98
loss_sample = sample (loss , args .sample_rate )
92
99
accuracy_sample = sample (accuracy , args .sample_rate )
93
100
94
- plot_metric (loss_sample , batch_sample , 'loss' )
95
- plot_metric (accuracy_sample , batch_sample , 'accuracy' )
101
+ plot_metric (loss_sample , batch_sample , 'loss' , line_label = 'loss' )
102
+ plot_metric (accuracy_sample , batch_sample , 'accuracy' , line_style = 'g-' , line_label = 'accuracy' )
96
103
97
104
98
105
if __name__ == '__main__' :
0 commit comments