Skip to content

Commit 1f4f044

Browse files
committed
A tiny fix in PyDataProvider2
* hidden decorator kwargs in DataProvider.__init__ * also add unit test for this.
1 parent 2965df5 commit 1f4f044

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

paddle/gserver/tests/test_PyDataProvider2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from paddle.trainer.PyDataProvider2 import *
1818

1919

20-
@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
20+
@provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
2121
def test_dense_no_seq(setting, filename):
2222
for i in xrange(200):
2323
yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)]

python/paddle/trainer/PyDataProvider2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def provider(input_types=None,
232232
check=False,
233233
check_fail_continue=False,
234234
init_hook=None,
235-
**kwargs):
235+
**outter_kwargs):
236236
"""
237237
Provider decorator. Use it to make a function into PyDataProvider2 object.
238238
In this function, user only need to get each sample for some train/test
@@ -318,11 +318,15 @@ def __init__(self, file_list, **kwargs):
318318
self.logger = logging.getLogger("")
319319
self.logger.setLevel(logging.INFO)
320320
self.input_types = None
321-
if 'slots' in kwargs:
321+
if 'slots' in outter_kwargs:
322322
self.logger.warning('setting slots value is deprecated, '
323323
'please use input_types instead.')
324-
self.slots = kwargs['slots']
325-
self.slots = input_types
324+
self.slots = outter_kwargs['slots']
325+
if input_types is not None:
326+
self.slots = input_types
327+
328+
assert self.slots is not None, \
329+
"Data Provider's input_types must be set"
326330
self.should_shuffle = should_shuffle
327331

328332
true_table = [1, 't', 'true', 'on']

0 commit comments

Comments
 (0)