Skip to content

Commit f04c97a

Browse files
authored
refine test_understand_sentiment_lstm (#5781)
* fix
* Fix a bug
1 parent 3e9ea34 commit f04c97a

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,51 +54,59 @@ def to_lodtensor(data, place):
5454
return res
5555

5656

57-
def chop_data(data, chop_len=80, batch_len=50):
57+
def chop_data(data, chop_len=80, batch_size=50):
5858
data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
5959

60-
return data[:batch_len]
60+
return data[:batch_size]
6161

6262

6363
def prepare_feed_data(data, place):
6464
tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
6565

6666
label = np.array(map(lambda x: x[1], data)).astype("int64")
67-
label = label.reshape([50, 1])
67+
label = label.reshape([len(label), 1])
6868
tensor_label = core.LoDTensor()
6969
tensor_label.set(label, place)
7070

7171
return tensor_words, tensor_label
7272

7373

7474
def main():
75-
word_dict = paddle.dataset.imdb.word_dict()
76-
cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2)
75+
BATCH_SIZE = 100
76+
PASS_NUM = 5
7777

78-
batch_size = 100
79-
train_data = paddle.batch(
80-
paddle.reader.buffered(
81-
paddle.dataset.imdb.train(word_dict), size=batch_size * 10),
82-
batch_size=batch_size)
78+
word_dict = paddle.dataset.imdb.word_dict()
79+
print "load word dict successfully"
80+
dict_dim = len(word_dict)
81+
class_dim = 2
8382

84-
data = chop_data(next(train_data()))
83+
cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim)
8584

85+
train_data = paddle.batch(
86+
paddle.reader.shuffle(
87+
paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
88+
batch_size=BATCH_SIZE)
8689
place = core.CPUPlace()
87-
tensor_words, tensor_label = prepare_feed_data(data, place)
8890
exe = Executor(place)
91+
8992
exe.run(framework.default_startup_program())
9093

91-
while True:
92-
outs = exe.run(framework.default_main_program(),
93-
feed={"words": tensor_words,
94-
"label": tensor_label},
95-
fetch_list=[cost, acc])
96-
cost_val = np.array(outs[0])
97-
acc_val = np.array(outs[1])
98-
99-
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
100-
if acc_val > 0.9:
101-
break
94+
for pass_id in xrange(PASS_NUM):
95+
for data in train_data():
96+
chopped_data = chop_data(data)
97+
tensor_words, tensor_label = prepare_feed_data(chopped_data, place)
98+
99+
outs = exe.run(framework.default_main_program(),
100+
feed={"words": tensor_words,
101+
"label": tensor_label},
102+
fetch_list=[cost, acc])
103+
cost_val = np.array(outs[0])
104+
acc_val = np.array(outs[1])
105+
106+
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
107+
if acc_val > 0.7:
108+
exit(0)
109+
exit(1)
102110

103111

104112
if __name__ == '__main__':

0 commit comments

Comments
 (0)