12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import paddle .trainer .PyDataProvider2 as dp2
16
15
import collections
16
+ import itertools
17
+
18
+ import paddle .trainer .PyDataProvider2 as dp2
19
+
17
20
import swig_paddle
18
- import numpy
19
21
20
22
__all__ = ['DataProviderConverter' ]
21
23
@@ -26,6 +28,12 @@ def __init__(self, input_type, pos):
26
28
assert isinstance (self .input_type , dp2 .InputType )
27
29
self .pos = pos
28
30
31
+ def pre_scan_loop (self , dat ):
32
+ pass
33
+
34
+ def finish_pre_scan (self , argument ):
35
+ pass
36
+
29
37
def scan (self , dat ):
30
38
pass
31
39
@@ -37,18 +45,24 @@ class DenseScanner(IScanner):
37
45
def __init__ (self , input_type , pos ):
38
46
IScanner .__init__ (self , input_type , pos )
39
47
self .__mat__ = None
48
+ self .__height__ = 0
49
+
50
+ def pre_scan_loop (self , dat ):
51
+ self .__height__ += 1
52
+
53
+ def finish_pre_scan (self , argument ):
54
+ self .__mat__ = swig_paddle .Matrix .createZero (self .__height__ ,
55
+ self .input_type .dim , False )
56
+ self .__height__ = 0
40
57
41
58
def scan (self , dat ):
42
- if self .__mat__ is None :
43
- self .__mat__ = numpy . array ([ dat ], dtype = 'float32' )
44
- else :
45
- self .__mat__ = numpy . append ( self . __mat__ , [ dat ], axis = 0 )
59
+ assert isinstance ( self .__mat__ , swig_paddle . Matrix )
60
+ a = self .__mat__ . toNumpyMatInplace ( )
61
+ a [ self . __height__ , ] = dat
62
+ self .__height__ += 1
46
63
47
64
def finish_scan (self , argument ):
48
- assert isinstance (argument , swig_paddle .Arguments )
49
- assert isinstance (self .input_type , dp2 .InputType )
50
- m = swig_paddle .Matrix .createDenseFromNumpy (self .__mat__ , True , False )
51
- argument .setSlotValue (self .pos , m )
65
+ argument .setSlotValue (self .pos , self .__mat__ )
52
66
53
67
54
68
class SparseBinaryScanner (IScanner ):
@@ -146,7 +160,14 @@ def convert(self, dat, argument=None):
146
160
]
147
161
148
162
for each_sample in dat :
149
- for each_step , scanner in zip (each_sample , scanners ):
163
+ for each_step , scanner in itertools .izip (each_sample , scanners ):
164
+ scanner .pre_scan_loop (each_step )
165
+
166
+ for scanner in scanners :
167
+ scanner .finish_pre_scan (argument )
168
+
169
+ for each_sample in dat :
170
+ for each_step , scanner in itertools .izip (each_sample , scanners ):
150
171
scanner .scan (each_step )
151
172
152
173
for scanner in scanners :
0 commit comments