Skip to content

Commit 77c8ddb

Browse files
committed
Move python/paddle/v2/plot/plot.py to /utils
1 parent d23c3ff commit 77c8ddb

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

python/paddle/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__all__ = ['dump_config']
15+
from plot import Ploter
16+
17+
__all__ = ['dump_config', 'Ploter']

python/paddle/utils/plot.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
18+
class PlotData(object):
19+
def __init__(self):
20+
self.step = []
21+
self.value = []
22+
23+
def append(self, step, value):
24+
self.step.append(step)
25+
self.value.append(value)
26+
27+
def reset(self):
28+
self.step = []
29+
self.value = []
30+
31+
32+
class Ploter(object):
33+
def __init__(self, *args):
34+
self.__args__ = args
35+
self.__plot_data__ = {}
36+
for title in args:
37+
self.__plot_data__[title] = PlotData()
38+
# demo in notebooks will use Ploter to plot figure, but when we convert
39+
# the ipydb to py file for testing, the import of matplotlib will make the
40+
# script crash. So we can use `export DISABLE_PLOT=True` to disable import
41+
# these libs
42+
self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
43+
if not self.__plot_is_disabled__():
44+
import matplotlib.pyplot as plt
45+
from IPython import display
46+
self.plt = plt
47+
self.display = display
48+
49+
def __plot_is_disabled__(self):
50+
return self.__disable_plot__ == "True"
51+
52+
def append(self, title, step, value):
53+
assert isinstance(title, basestring)
54+
assert self.__plot_data__.has_key(title)
55+
data = self.__plot_data__[title]
56+
assert isinstance(data, PlotData)
57+
data.append(step, value)
58+
59+
def plot(self, path=None):
60+
if self.__plot_is_disabled__():
61+
return
62+
63+
titles = []
64+
for title in self.__args__:
65+
data = self.__plot_data__[title]
66+
assert isinstance(data, PlotData)
67+
if len(data.step) > 0:
68+
titles.append(title)
69+
self.plt.plot(data.step, data.value)
70+
self.plt.legend(titles, loc='upper left')
71+
if path is None:
72+
self.display.clear_output(wait=True)
73+
self.display.display(self.plt.gcf())
74+
else:
75+
self.plt.savefig(path)
76+
self.plt.gcf().clear()
77+
78+
def reset(self):
79+
for key in self.__plot_data__:
80+
data = self.__plot_data__[key]
81+
assert isinstance(data, PlotData)
82+
data.reset()

0 commit comments

Comments
 (0)