Skip to content

Commit 33b8164

Browse files
committed
Fix bug in multple objects in define_py_sources
1 parent 8295eb9 commit 33b8164

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

python/paddle/trainer_config_helpers/data_sources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __is_splitable__(o):
139139
test_obj = obj
140140
train_obj = obj
141141
if __is_splitable__(obj):
142-
train_module, test_module = module
142+
train_obj, test_obj = obj
143143

144144
if args is None:
145145
args = ""

python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer
1111
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
1212
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
1313
test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
14-
test_bilinear_interp test_maxout test_bi_grumemory math_ops)
14+
test_bilinear_interp test_maxout test_bi_grumemory math_ops
15+
test_spilit_datasource)
1516

1617

1718
for conf in ${configs[*]}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from paddle.trainer_config_helpers import *
2+
3+
define_py_data_sources2(train_list="train.list",
4+
test_list="test.list",
5+
module=["a", "b"],
6+
obj=("c", "d"))
7+
settings(
8+
learning_rate=1e-3,
9+
batch_size=1000
10+
)
11+
12+
outputs(data_layer(name="a", size=10))

0 commit comments

Comments
 (0)