Skip to content

Commit 892cc82

Browse files
authored
Merge pull request #1766 from reyoung/feature/add_list_type_of_feeding
Add list type of feeding
2 parents fbea391 + 66f5052 commit 892cc82

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

python/paddle/v2/data_feeder.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from py_paddle import DataProviderConverter
16-
16+
import collections
1717
import paddle.trainer.PyDataProvider2 as pydp2
1818

1919
__all__ = ['DataFeeder']
@@ -35,15 +35,30 @@ class DataFeeder(DataProviderConverter):
3535
DataFeeder converts this mini-batch data entries into Arguments in order
3636
to feed it to C++ interface.
3737
38-
The example usage:
38+
The simple usage shows below
39+
40+
.. code-block:: python
41+
42+
feeding = ['image', 'label']
43+
data_types = enumerate_data_types_of_data_layers(topology)
44+
feeder = DataFeeder(data_types=data_types, feeding=feeding)
45+
46+
minibatch_data = [([1.0, 2.0, 3.0, ...], 5)]
47+
48+
arg = feeder(minibatch_data)
49+
50+
51+
If mini-batch data and data layers are not one to one mapping, we
52+
could pass a dictionary to feeding parameter to represent the mapping
53+
relationship.
3954
4055
4156
.. code-block:: python
4257
4358
data_types = [('image', paddle.data_type.dense_vector(784)),
4459
('label', paddle.data_type.integer_value(10))]
45-
reader_dict = {'image':0, 'label':1}
46-
feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict)
60+
feeding = {'image':0, 'label':1}
61+
feeder = DataFeeder(data_types=data_types, feeding=feeding)
4762
minibatch_data = [
4863
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample
4964
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample
@@ -65,16 +80,23 @@ class DataFeeder(DataProviderConverter):
6580
a tuple of (data_name, data_type).
6681
6782
:type data_types: list
68-
:param reader_dict: A dictionary to specify the position of each data
69-
in the input data.
70-
:type feeding: dict
83+
:param feeding: A dictionary or a sequence to specify the position of each
84+
data in the input data.
85+
:type feeding: dict|collections.Sequence|None
7186
"""
7287

7388
def __init__(self, data_types, feeding=None):
7489
self.input_names = []
7590
input_types = []
7691
if feeding is None:
7792
feeding = default_feeding_map(data_types)
93+
elif isinstance(feeding, collections.Sequence):
94+
feed_list = feeding
95+
feeding = dict()
96+
for i, name in enumerate(feed_list):
97+
feeding[name] = i
98+
elif not isinstance(feeding, dict):
99+
raise TypeError("Feeding should be dict or sequence or None.")
78100

79101
self.feeding = feeding
80102
for each in data_types:

python/paddle/v2/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def train(self, reader, num_passes=1, event_handler=None, feeding=None):
8181
:type event_handler: (BaseEvent) => None
8282
:param feeding: Feeding is a map of neural network input name and array
8383
index that reader returns.
84-
:type feeding: dict
84+
:type feeding: dict|list
8585
:return:
8686
"""
8787
if event_handler is None:

0 commit comments

Comments
 (0)