Skip to content

Commit 9acfc21

Browse files
committed
Faster DataProvider Converter
1 parent 50434cb commit 9acfc21

File tree

2 files changed

+35 additions, -11 deletions (lines changed)

2 files changed

+35 additions, -11 deletions (lines changed)

demo/mnist/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,3 +5,6 @@ plot.png
55
train.log
66
*pyc
77
.ipynb_checkpoints
8+
*.w0
9+
*.wbias
10+
*.bin

paddle/py_paddle/dataprovider_converter.py

Lines changed: 32 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import paddle.trainer.PyDataProvider2 as dp2
1615
import collections
16+
import itertools
17+
18+
import paddle.trainer.PyDataProvider2 as dp2
19+
1720
import swig_paddle
18-
import numpy
1921

2022
__all__ = ['DataProviderConverter']
2123

@@ -26,6 +28,12 @@ def __init__(self, input_type, pos):
2628
assert isinstance(self.input_type, dp2.InputType)
2729
self.pos = pos
2830

31+
def pre_scan_loop(self, dat):
32+
pass
33+
34+
def finish_pre_scan(self, argument):
35+
pass
36+
2937
def scan(self, dat):
3038
pass
3139

@@ -37,18 +45,24 @@ class DenseScanner(IScanner):
3745
def __init__(self, input_type, pos):
3846
IScanner.__init__(self, input_type, pos)
3947
self.__mat__ = None
48+
self.__height__ = 0
49+
50+
def pre_scan_loop(self, dat):
51+
self.__height__ += 1
52+
53+
def finish_pre_scan(self, argument):
54+
self.__mat__ = swig_paddle.Matrix.createZero(self.__height__,
55+
self.input_type.dim, False)
56+
self.__height__ = 0
4057

4158
def scan(self, dat):
42-
if self.__mat__ is None:
43-
self.__mat__ = numpy.array([dat], dtype='float32')
44-
else:
45-
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
59+
assert isinstance(self.__mat__, swig_paddle.Matrix)
60+
a = self.__mat__.toNumpyMatInplace()
61+
a[self.__height__, ] = dat
62+
self.__height__ += 1
4663

4764
def finish_scan(self, argument):
48-
assert isinstance(argument, swig_paddle.Arguments)
49-
assert isinstance(self.input_type, dp2.InputType)
50-
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
51-
argument.setSlotValue(self.pos, m)
65+
argument.setSlotValue(self.pos, self.__mat__)
5266

5367

5468
class SparseBinaryScanner(IScanner):
@@ -146,7 +160,14 @@ def convert(self, dat, argument=None):
146160
]
147161

148162
for each_sample in dat:
149-
for each_step, scanner in zip(each_sample, scanners):
163+
for each_step, scanner in itertools.izip(each_sample, scanners):
164+
scanner.pre_scan_loop(each_step)
165+
166+
for scanner in scanners:
167+
scanner.finish_pre_scan(argument)
168+
169+
for each_sample in dat:
170+
for each_step, scanner in itertools.izip(each_sample, scanners):
150171
scanner.scan(each_step)
151172

152173
for scanner in scanners:

0 commit comments

Comments
 (0)