13
13
# limitations under the License.
14
14
15
15
from py_paddle import DataProviderConverter
16
-
16
+ import collections
17
17
import paddle .trainer .PyDataProvider2 as pydp2
18
18
19
19
__all__ = ['DataFeeder' ]
@@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter):
35
35
DataFeeder converts this mini-batch data entries into Arguments in order
36
36
to feed it to C++ interface.
37
37
38
- The example usage:
38
+ The simple usage shows below
39
+
40
+ .. code-block:: python
41
+
42
+ feeding = ['image', 'label']
43
+ data_types = enumerate_data_types_of_data_layers(topology)
44
+ feeder = DataFeeder(data_types=data_types, feeding=feeding)
45
+
46
+ minibatch_data = [([1.0, 2.0, 3.0, ...], 5)]
47
+
48
+ arg = feeder(minibatch_data)
49
+
50
+
51
+ If mini-batch data and data layers are not one to one mapping, we
52
+ could pass a dictionary to feeding parameter to represent the mapping
53
+ relationship.
39
54
40
55
41
56
.. code-block:: python
42
57
43
58
data_types = [('image', paddle.data_type.dense_vector(784)),
44
59
('label', paddle.data_type.integer_value(10))]
45
- reader_dict = {'image':0, 'label':1}
46
- feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict )
60
+ feeding = {'image':0, 'label':1}
61
+ feeder = DataFeeder(data_types=data_types, feeding=feeding )
47
62
minibatch_data = [
48
63
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample
49
64
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample
@@ -65,16 +80,23 @@ class DataFeeder(DataProviderConverter):
65
80
a tuple of (data_name, data_type).
66
81
67
82
:type data_types: list
68
- :param reader_dict : A dictionary to specify the position of each data
69
- in the input data.
70
- :type feeding: dict
83
+ :param feeding : A dictionary or a sequence to specify the position of each
84
+ data in the input data.
85
+ :type feeding: dict|collections.Sequence|None
71
86
"""
72
87
73
88
def __init__ (self , data_types , feeding = None ):
74
89
self .input_names = []
75
90
input_types = []
76
91
if feeding is None :
77
92
feeding = default_feeding_map (data_types )
93
+ elif isinstance (feeding , collections .Sequence ):
94
+ feed_list = feeding
95
+ feeding = dict ()
96
+ for i , name in enumerate (feed_list ):
97
+ feeding [name ] = i
98
+ elif not isinstance (feeding , dict ):
99
+ raise TypeError ("Feeding should be dict or sequence or None." )
78
100
79
101
self .feeding = feeding
80
102
for each in data_types :
0 commit comments