
Commit 665eb01

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into bi_tensor_prod_op
2 parents: ab41648 + 5fe9746

File tree: 4 files changed, +179 / -17 lines


paddle/platform/call_once.h

Lines changed: 13 additions & 11 deletions
@@ -27,20 +27,22 @@ namespace platform {
   This wrap is a hack to avoid this bug.
 */
-template <class Callable, class... Args>
+template <typename Callable, typename... Args>
 inline void call_once(std::once_flag& flag, Callable&& f, Args&&... args) {
   bool good = false;
   std::exception ex;
-  std::call_once(flag, [&]() {
-    try {
-      f(args...);
-      good = true;
-    } catch (const std::exception& e) {
-      ex = e;
-    } catch (...) {
-      ex = std::runtime_error("excption caught in call_once");
-    }
-  });
+  std::call_once(flag,
+                 [&](Args&&... args) {
+                   try {
+                     f(args...);
+                     good = true;
+                   } catch (const std::exception& e) {
+                     ex = e;
+                   } catch (...) {
+                     ex = std::runtime_error("excption caught in call_once");
+                   }
+                 },
+                 args...);
   if (!good) {
     throw std::exception(ex);
   }

python/paddle/v2/framework/layer_helper.py

Lines changed: 6 additions & 5 deletions
@@ -4,7 +4,7 @@
 from paddle.v2.framework.framework import Variable, g_main_program, \
     g_startup_program, unique_name, Program
 from paddle.v2.framework.initializer import ConstantInitializer, \
-    UniformInitializer
+    UniformInitializer, XavierInitializer
 
 
 class LayerHelper(object):
@@ -61,7 +61,7 @@ def input(self, input_param_name='input'):
 
     @property
     def param_attr(self):
-        default = {'name': None, 'initializer': UniformInitializer()}
+        default = {'name': None, 'initializer': XavierInitializer()}
         actual = self.kwargs.get('param_attr', None)
         if actual is None:
             actual = default
@@ -70,10 +70,11 @@ def param_attr(self):
                 actual[default_field] = default[default_field]
         return actual
 
+    @property
     def bias_attr(self):
-        default = {'name': None, 'initializer': ConstantInitializer()}
+        default = {'name': None, 'initializer': XavierInitializer()}
         bias_attr = self.kwargs.get('bias_attr', None)
-        if bias_attr is True:
+        if bias_attr is None:
             bias_attr = default
 
         if isinstance(bias_attr, dict):
@@ -166,7 +167,7 @@ def append_bias_op(self, input_var, num_flatten_dims=None):
             num_flatten_dims = 1
 
         size = list(input_var.shape[num_flatten_dims:])
-        bias_attr = self.bias_attr()
+        bias_attr = self.bias_attr
         if not bias_attr:
             return input_var
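For orientation, a minimal standalone sketch of the bias_attr defaulting shown in this diff; XavierInitializer is stubbed and the LayerHelperSketch name is illustrative, not part of the change:

# Sketch of the new bias_attr property semantics (assumptions noted above).
class XavierInitializer(object):
    # Stub standing in for paddle.v2.framework.initializer.XavierInitializer.
    pass


class LayerHelperSketch(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @property
    def bias_attr(self):
        # bias_attr=None (the new default) falls back to a Xavier-initialized
        # bias; an explicit falsy value such as False still disables the bias,
        # since append_bias_op checks `if not bias_attr: return input_var`.
        default = {'name': None, 'initializer': XavierInitializer()}
        bias_attr = self.kwargs.get('bias_attr', None)
        if bias_attr is None:
            bias_attr = default
        return bias_attr


print(LayerHelperSketch().bias_attr)                 # Xavier default applied
print(LayerHelperSketch(bias_attr=False).bias_attr)  # False -> bias skipped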

python/paddle/v2/framework/layers.py

Lines changed: 50 additions & 1 deletion
@@ -16,7 +16,7 @@
 def fc(input,
        size,
        param_attr=None,
-       bias_attr=True,
+       bias_attr=None,
        name=None,
        act=None,
        num_flatten_dims=1,
@@ -125,6 +125,55 @@ def embedding(input,
     return tmp
 
 
+# TODO(qijun): expose H0 and C0
+def dynamic_lstm(input,
+                 size,
+                 data_type='float32',
+                 param_attr=None,
+                 bias_attr=None,
+                 use_peepholes=True,
+                 is_reverse=False,
+                 gate_activation='sigmoid',
+                 cell_activation='tanh',
+                 candidate_activation='tanh',
+                 main_program=None,
+                 startup_program=None):
+    helper = LayerHelper('lstm', **locals())
+    size = size / 4
+    weight = helper.create_parameter(
+        attr=helper.param_attr, shape=[size, 4 * size], dtype=data_type)
+    bias_size = [1, 7 * size]
+    if not use_peepholes:
+        bias_size[1] = 4 * size
+    bias = helper.create_parameter(
+        attr=helper.bias_attr, shape=bias_size, dtype=data_type, suffix='b')
+
+    hidden = helper.create_tmp_variable(data_type)
+    cell = helper.create_tmp_variable(data_type)
+    batch_gate = helper.create_tmp_variable(data_type)
+    batch_cell_pre_act = helper.create_tmp_variable(data_type)
+
+    helper.append_op(
+        type='lstm',
+        inputs={'Input': input,
+                'Weight': weight,
+                'Bias': bias},
+        outputs={
+            'Hidden': hidden,
+            'Cell': cell,
+            'BatchGate': batch_gate,
+            'BatchCellPreAct': batch_cell_pre_act
+        },
+        attrs={
+            'use_peepholes': use_peepholes,
+            'is_reverse': is_reverse,
+            'gate_activation': gate_activation,
+            'cell_activation': cell_activation,
+            'candidate_activation': candidate_activation
+        })
+    return hidden, cell
+
+
 def data(name,
          shape,
          data_type='float32',
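A hedged usage sketch of the new dynamic_lstm layer, mirroring the test added below; the dimension names (dict_dim, emb_dim, hid_dim) are illustrative. Because the layer divides size by four internally, it is normally fed a fully connected projection of that same width and returns both the hidden and the cell sequence:

import paddle.v2.framework.layers as layers

# Illustrative dimensions; only the structure matters here.
dict_dim, emb_dim, hid_dim = 10000, 128, 512

words = layers.data(name="words", shape=[1], data_type="int64")
emb = layers.embedding(input=words, size=[dict_dim, emb_dim])

# dynamic_lstm expects an input that is already 4 * hidden units wide
# (size / 4 becomes the actual hidden width), so the usual pattern is an
# fc projection of width `size` feeding the LSTM.
fc1 = layers.fc(input=emb, size=hid_dim)
hidden, cell = layers.dynamic_lstm(
    input=fc1,
    size=hid_dim,
    use_peepholes=True,   # peephole weights are packed into the bias parameter
    is_reverse=False)     # set True to process each sequence backwards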
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+import paddle.v2 as paddle
+import paddle.v2.framework.layers as layers
+import paddle.v2.framework.nets as nets
+import paddle.v2.framework.core as core
+import paddle.v2.framework.optimizer as optimizer
+
+from paddle.v2.framework.framework import Program, g_main_program, g_startup_program
+from paddle.v2.framework.executor import Executor
+
+import numpy as np
+
+
+def stacked_lstm_net(input_dim,
+                     class_dim=2,
+                     emb_dim=128,
+                     hid_dim=512,
+                     stacked_num=3):
+    assert stacked_num % 2 == 1
+    data = layers.data(name="words", shape=[1], data_type="int64")
+    label = layers.data(name="label", shape=[1], data_type="int64")
+
+    emb = layers.embedding(input=data, size=[input_dim, emb_dim])
+    # add bias attr
+
+    # TODO(qijun) linear act
+    fc1 = layers.fc(input=emb, size=hid_dim)
+    lstm1, cell1 = layers.dynamic_lstm(input=fc1, size=hid_dim)
+
+    inputs = [fc1, lstm1]
+
+    for i in range(2, stacked_num + 1):
+        fc = layers.fc(input=inputs, size=hid_dim)
+        lstm, cell = layers.dynamic_lstm(
+            input=fc, size=hid_dim, is_reverse=(i % 2) == 0)
+        inputs = [fc, lstm]
+
+    fc_last = layers.sequence_pool(input=inputs[0], pool_type='max')
+    lstm_last = layers.sequence_pool(input=inputs[1], pool_type='max')
+
+    prediction = layers.fc(input=[fc_last, lstm_last],
+                           size=class_dim,
+                           act='softmax')
+    cost = layers.cross_entropy(input=prediction, label=label)
+    avg_cost = layers.mean(x=cost)
+    adam_optimizer = optimizer.AdamOptimizer(learning_rate=0.002)
+    opts = adam_optimizer.minimize(avg_cost)
+    acc = layers.accuracy(input=prediction, label=label)
+    return avg_cost, acc
+
+
+def to_lodtensor(data, place):
+    seq_lens = [len(seq) for seq in data]
+    cur_len = 0
+    lod = [cur_len]
+    for l in seq_lens:
+        cur_len += l
+        lod.append(cur_len)
+    flattened_data = np.concatenate(data, axis=0).astype("int64")
+    flattened_data = flattened_data.reshape([len(flattened_data), 1])
+    res = core.LoDTensor()
+    res.set(flattened_data, place)
+    res.set_lod([lod])
+    return res
+
+
+def main():
+    BATCH_SIZE = 100
+    PASS_NUM = 5
+
+    word_dict = paddle.dataset.imdb.word_dict()
+    print "load word dict successfully"
+    dict_dim = len(word_dict)
+    class_dim = 2
+
+    cost, acc = stacked_lstm_net(input_dim=dict_dim, class_dim=class_dim)
+
+    train_data = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.imdb.train(word_dict), buf_size=1000),
+        batch_size=BATCH_SIZE)
+    place = core.CPUPlace()
+    exe = Executor(place)
+
+    exe.run(g_startup_program)
+
+    for pass_id in xrange(PASS_NUM):
+        for data in train_data():
+            tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
+
+            label = np.array(map(lambda x: x[1], data)).astype("int64")
+            label = label.reshape([BATCH_SIZE, 1])
+
+            tensor_label = core.LoDTensor()
+            tensor_label.set(label, place)
+
+            outs = exe.run(g_main_program,
+                           feed={"words": tensor_words,
+                                 "label": tensor_label},
+                           fetch_list=[cost, acc])
+            cost_val = np.array(outs[0])
+            acc_val = np.array(outs[1])
+
+            print("cost=" + str(cost_val) + " acc=" + str(acc_val))
+            if cost_val < 1.0 and acc_val > 0.7:
+                exit(0)
+    exit(1)
+
+
+if __name__ == '__main__':
+    main()
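As a quick worked example of the LoD layout that to_lodtensor builds, a standalone numpy sketch; the three sample sequences are made up for illustration:

import numpy as np

# Three variable-length sequences of word ids, as the imdb reader yields them.
data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]

lod = [0]
for seq in data:
    lod.append(lod[-1] + len(seq))   # cumulative offsets into the flat tensor

flattened = np.concatenate(data, axis=0).astype("int64").reshape([-1, 1])

print(lod)               # [0, 3, 5, 9] -- sequence i spans rows lod[i]:lod[i+1]
print(flattened.shape)   # (9, 1)       -- all word ids stacked into one column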
