Skip to content

Commit 4a93486

Browse files
authored
Merge pull request #14031 from shippingwang/fix_plot_1.0
Fix plot 1.0
2 parents 587f3dd + 7931f8d commit 4a93486

File tree

2 files changed

+118
-1
lines changed

2 files changed

+118
-1
lines changed

python/paddle/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__all__ = ['dump_config']
15+
from plot import Ploter
16+
17+
__all__ = ['dump_config', 'Ploter']

python/paddle/utils/plot.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
18+
class PlotData(object):
19+
def __init__(self):
20+
self.step = []
21+
self.value = []
22+
23+
def append(self, step, value):
24+
self.step.append(step)
25+
self.value.append(value)
26+
27+
def reset(self):
28+
self.step = []
29+
self.value = []
30+
31+
32+
class Ploter(object):
33+
"""
34+
Plot input data in a 2D graph
35+
36+
Args:
37+
title: assign the title of input data.
38+
step: x_axis of the data.
39+
value: y_axis of the data.
40+
"""
41+
42+
def __init__(self, *args):
43+
self.__args__ = args
44+
self.__plot_data__ = {}
45+
for title in args:
46+
self.__plot_data__[title] = PlotData()
47+
# demo in notebooks will use Ploter to plot figure, but when we convert
48+
# the ipydb to py file for testing, the import of matplotlib will make the
49+
# script crash. So we can use `export DISABLE_PLOT=True` to disable import
50+
# these libs
51+
self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
52+
if not self.__plot_is_disabled__():
53+
import matplotlib.pyplot as plt
54+
from IPython import display
55+
self.plt = plt
56+
self.display = display
57+
58+
def __plot_is_disabled__(self):
59+
return self.__disable_plot__ == "True"
60+
61+
def append(self, title, step, value):
62+
"""
63+
Feed data
64+
65+
Args:
66+
title: assign the group data to this subtitle.
67+
step: the x_axis of data.
68+
value: the y_axis of data.
69+
70+
Examples:
71+
.. code-block:: python
72+
plot_curve = Ploter("Curve 1","Curve 2")
73+
plot_curve.append(title="Curve 1",step=1,value=1)
74+
"""
75+
assert isinstance(title, basestring)
76+
assert self.__plot_data__.has_key(title)
77+
data = self.__plot_data__[title]
78+
assert isinstance(data, PlotData)
79+
data.append(step, value)
80+
81+
def plot(self, path=None):
82+
"""
83+
Plot data in a 2D graph
84+
85+
Args:
86+
path: store the figure to this file path. Defaul None.
87+
88+
Examples:
89+
.. code-block:: python
90+
plot_curve = Ploter()
91+
plot_cure.plot()
92+
"""
93+
if self.__plot_is_disabled__():
94+
return
95+
96+
titles = []
97+
for title in self.__args__:
98+
data = self.__plot_data__[title]
99+
assert isinstance(data, PlotData)
100+
if len(data.step) > 0:
101+
titles.append(title)
102+
self.plt.plot(data.step, data.value)
103+
self.plt.legend(titles, loc='upper left')
104+
if path is None:
105+
self.display.clear_output(wait=True)
106+
self.display.display(self.plt.gcf())
107+
else:
108+
self.plt.savefig(path)
109+
self.plt.gcf().clear()
110+
111+
def reset(self):
112+
for key in self.__plot_data__:
113+
data = self.__plot_data__[key]
114+
assert isinstance(data, PlotData)
115+
data.reset()

0 commit comments

Comments
 (0)