
Commit b15a478

emailweixu authored and luotao1 committed
Correctly handling multiple inputs and integer inputs for recurrent_group (#114)
* Correctly handling multiple inputs and integer inputs for recurrent_group
* Fix ScatterAgentLayer for generation
* Revert sequence_(nest)_rnn.conf
1 parent ffc3416 commit b15a478
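
The pattern this commit fixes — a recurrent_group fed both a dense sequence and a raw integer (id) input — is exercised by the two new test configs below. As a condensed sketch (same illustrative dimensions as the test config, not a definitive usage guide):

from paddle.trainer_config_helpers import *

# One dense sequence in-link plus one integer-id in-link to the same group;
# the step function receives one argument per in-link, in order.
data = data_layer(name="word", size=10)    # integer word ids
emb = embedding_layer(input=data, size=8)  # dense sequence

def step(y, wid):
    z = embedding_layer(input=wid, size=8)  # raw ids usable inside the step
    mem = memory(name="rnn_state", size=8)
    return fc_layer(input=[y, z, mem], size=8, act=TanhActivation(),
                    bias_attr=True, name="rnn_state")

out = recurrent_group(name="rnn", step=step, input=[emb, data])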

File tree: 10 files changed, +210 -47 lines


paddle/cuda/src/hl_cuda_cublas.cc

Lines changed: 2 additions & 2 deletions
@@ -217,7 +217,7 @@ void hl_matrix_mul(real *A_d, hl_trans_op_t transa,
   } else {
     LOG(FATAL) << "parameter transa error!";
   }
-  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS);
+  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
   CHECK_SYNC("hl_matrix_mul failed");
 }

@@ -266,7 +266,7 @@ void hl_matrix_mul_vector(real *A_d, hl_trans_op_t trans,
     LOG(FATAL) << "parameter transa error!";
   }

-  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS);
+  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
   CHECK_SYNC("hl_matrix_mul_vector");
 }

paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp

Lines changed: 25 additions & 16 deletions
@@ -497,20 +497,21 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
   int idSize = 0;
   // connect in_links
   for (size_t j = 0; j < inFrameLines_.size(); ++j) {
+    Info& info = info_[shareInlinkInfo ? 0 : j];
     // idSize denotes the sum number of tokens in each length i
-    idSize = info_[j].idIndex[i + 1] - info_[j].idIndex[i];
+    idSize = info.idIndex[i + 1] - info.idIndex[i];
     InFrameLine inFrameLine = inFrameLines_[j];
     auto scatterAgent =
         dynamic_cast<ScatterAgentLayer*>(inFrameLine.agents[i].get());
     scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer,
-                                        inFrameLine.outArg, info_[j].allIds,
-                                        info_[j].idIndex[i], idSize);
+                                        inFrameLine.outArg, info.allIds,
+                                        info.idIndex[i], idSize);
     if (hasSubseq) {
       // size: the length of subsequence
       int size =
-          info_[j].seqStartPosIndex[i + 1] - info_[j].seqStartPosIndex[i];
-      scatterAgent->setSequenceStartPositions(info_[j].sequenceStartPositions,
-                                              info_[j].seqStartPosIndex[i],
+          info.seqStartPosIndex[i + 1] - info.seqStartPosIndex[i];
+      scatterAgent->setSequenceStartPositions(info.sequenceStartPositions,
+                                              info.seqStartPosIndex[i],
                                               size);
     }
   }

@@ -744,16 +745,24 @@ void RecurrentGradientMachine::selectRowsOneTime(LayerPtr layer,
                                                  const IVectorPtr& allIds,
                                                  Argument* arg,
                                                  PassType passType) {
-  const MatrixPtr& realV = layer->getOutputValue();
-  int height = realV->getHeight();
-  int width = realV->getWidth();
-  Matrix::resizeOrCreate(arg->value, height, width, /* trans */ false, useGpu_);
-  arg->value->zeroMem();
-  arg->value->selectRows(*realV, *allIds);
-  if (passType != PASS_TEST) {
-    Matrix::resizeOrCreate(arg->grad, height, width, /* trans */ false,
-                           useGpu_);
-    arg->grad->zeroMem();
+  Argument& src = layer->getOutput();
+  if (src.value) {
+    const MatrixPtr& realV = src.value;
+    int height = realV->getHeight();
+    int width = realV->getWidth();
+    Matrix::resizeOrCreate(
+        arg->value, height, width, /* trans */ false, useGpu_);
+    arg->value->zeroMem();
+    arg->value->selectRows(*realV, *allIds);
+    if (passType != PASS_TEST) {
+      Matrix::resizeOrCreate(arg->grad, height, width, /* trans */ false,
+                             useGpu_);
+      arg->grad->zeroMem();
+    }
+  }
+  if (src.ids) {
+    IVector::resizeOrCreate(arg->ids, src.ids->getSize(), useGpu_);
+    arg->ids->selectFrom(*src.ids, *allIds);
   }
 }

paddle/gserver/layers/AgentLayer.cpp

Lines changed: 13 additions & 13 deletions
@@ -139,15 +139,16 @@ void ScatterAgentLayer::forward(PassType passType) {
   Layer::forward(passType);
   CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());

-  if (realLayer_->getOutput().ids) {  // ids scatter
-    IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
-    output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
-  } else {  // value scatter
-    int width = this->getSize();
-    if (realOutArg_.value) {
-      output_.subArgFrom(realOutArg_, /* offset */ idIndex_ * width, idSize_,
-                         width, useGpu_);
-    } else {  // used in generation
+  int width = this->getSize();
+  if (realOutArg_.value || realOutArg_.ids) {
+    output_.subArgFrom(realOutArg_, /* offset */ idIndex_, idSize_,
+                       width, useGpu_);
+  } else {  // used in generation
+    if (realLayer_->getOutput().ids) {
+      IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
+      output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
+    }
+    if (realLayer_->getOutput().value) {
       int height = ids_->getSize();
       resetOutput(height, width);

@@ -213,18 +214,17 @@ void SequenceGatherAgentLayer::forward(PassType passType) {
 void SequenceScatterAgentLayer::forward(PassType passType) {
   Layer::forward(passType);
   CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
-  CHECK(!realLayer_->getOutput().ids) << "Not supported";

   const Argument& input = realLayer_->getOutput();
-  CHECK_EQ(input.value->getWidth(), this->getSize());
+  CHECK_EQ(realLayer_->getSize(), this->getSize());
   int width = this->getSize();

   AsyncGpuBlock asyncGpuBlock;
   REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str());

-  if (realOutArg_.value) {
+  if (realOutArg_.value || realOutArg_.ids) {
     CHECK(realOutArg_.sequenceStartPositions);
-    output_.subArgFrom(realOutArg_, /* offset */ idIndex_ * width, idSize_,
+    output_.subArgFrom(realOutArg_, /* offset */ idIndex_, idSize_,
                        width, useGpu_, /* trans */ false, /* seqFlag */ true,
                        /* seqStart */ seqStartPosIndex_,
                        /* seqSize */ numSequences_);

paddle/gserver/tests/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ add_test(NAME test_RecurrentGradientMachine
   COMMAND .set_python_path.sh -d
     ${PROJ_ROOT}/python:${PROJ_ROOT}/paddle/gserver/tests
     ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine
-    --use_gpu=false
   WORKING_DIRECTORY ${PROJ_ROOT}/paddle)

 add_unittest_without_exec(test_NetworkCompare
paddle/gserver/tests/sequence_nest_rnn_multi_input.conf (new file; path inferred from the test references in test_RecurrentGradientMachine.cpp below)

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+######################## data source ################################
+define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
+                        test_list=None,
+                        module='rnn_data_provider',
+                        obj='process_subseq')
+
+
+settings(batch_size=2, learning_rate=0.01)
+######################## network configure ################################
+dict_dim = 10
+word_dim = 8
+hidden_dim = 8
+label_dim = 3
+
+data = data_layer(name="word", size=dict_dim)
+
+emb = embedding_layer(input=data, size=word_dim)
+
+# This hierarchical RNN is designed to be equivalent to the simple RNN in
+# sequence_rnn.conf
+
+def outer_step(wid, x):
+    outer_mem = memory(name="outer_rnn_state", size=hidden_dim)
+    def inner_step(y, wid):
+        z = embedding_layer(input=wid, size=word_dim)
+        inner_mem = memory(name="inner_rnn_state",
+                           size=hidden_dim,
+                           boot_layer=outer_mem)
+        out = fc_layer(input=[y, z, inner_mem],
+                       size=hidden_dim,
+                       act=TanhActivation(),
+                       bias_attr=True,
+                       name="inner_rnn_state")
+        return out
+
+    inner_rnn_output = recurrent_group(
+        step=inner_step,
+        name="inner",
+        input=[x, wid])
+    last = last_seq(input=inner_rnn_output, name="outer_rnn_state")
+
+    # "return last" should also work. But currently RecurrentGradientMachine
+    # does not handle it correctly. Current implementation requires that
+    # all the out links are from sequences. However, it does not report error
+    # when the out links are not sequences.
+    return inner_rnn_output
+
+out = recurrent_group(
+    name="outer",
+    step=outer_step,
+    input=[SubsequenceInput(data), SubsequenceInput(emb)])
+
+rep = last_seq(input=out)
+prob = fc_layer(size=label_dim,
+                input=rep,
+                act=SoftmaxActivation(),
+                bias_attr=True)
+
+outputs(classification_cost(input=prob,
+                            label=data_layer(name="label", size=label_dim)))
paddle/gserver/tests/sequence_rnn_multi_input.conf (new file; path inferred from the test references in test_RecurrentGradientMachine.cpp below)

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+######################## data source ################################
+define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
+                        test_list=None,
+                        module='rnn_data_provider',
+                        obj='process_seq')
+
+
+settings(batch_size=2, learning_rate=0.01)
+######################## network configure ################################
+dict_dim = 10
+word_dim = 8
+hidden_dim = 8
+label_dim = 3
+
+data = data_layer(name="word", size=dict_dim)
+
+emb = embedding_layer(input=data, size=word_dim)
+
+def step(y, wid):
+    z = embedding_layer(input=wid, size=word_dim)
+    mem = memory(name="rnn_state", size=hidden_dim)
+    out = fc_layer(input=[y, z, mem],
+                   size=hidden_dim,
+                   act=TanhActivation(),
+                   bias_attr=True,
+                   name="rnn_state")
+    return out
+
+out = recurrent_group(
+    name="rnn",
+    step=step,
+    input=[emb, data])
+
+rep = last_seq(input=out)
+prob = fc_layer(size=label_dim,
+                input=rep,
+                act=SoftmaxActivation(),
+                bias_attr=True)
+
+outputs(classification_cost(input=prob,
+                            label=data_layer(name="label", size=label_dim)))

paddle/gserver/tests/test_RecurrentGradientMachine.cpp

Lines changed: 22 additions & 7 deletions
@@ -92,7 +92,11 @@ void CalCost(const string& conf, const string& dir, real* cost,
   rmDir(dir.c_str());
 }

-void test(const string& conf1, const string& conf2, double eps) {
+void test(const string& conf1, const string& conf2, double eps, bool useGpu) {
+  if (!paddle::version::isWithGpu() && useGpu) {
+    return;
+  }
+  FLAGS_use_gpu = useGpu;
   int num_passes = 5;
   real* cost1 = new real[num_passes];
   const string dir1 = "gserver/tests/t1";

@@ -113,17 +117,28 @@ void test(const string& conf1, const string& conf2, double eps) {
 }

 TEST(RecurrentGradientMachine, HasSubSequence) {
-  test("gserver/tests/sequence_layer_group.conf",
-       "gserver/tests/sequence_nest_layer_group.conf",
-       1e-5);
+  for (bool useGpu : {false, true}) {
+    test("gserver/tests/sequence_layer_group.conf",
+         "gserver/tests/sequence_nest_layer_group.conf",
+         1e-5, useGpu);
+  }
 }

 TEST(RecurrentGradientMachine, rnn) {
-  test("gserver/tests/sequence_rnn.conf",
-       "gserver/tests/sequence_nest_rnn.conf",
-       0);
+  for (bool useGpu : {false, true}) {
+    test("gserver/tests/sequence_rnn.conf",
+         "gserver/tests/sequence_nest_rnn.conf",
+         1e-6, useGpu);
+  }
 }

+TEST(RecurrentGradientMachine, rnn_multi_input) {
+  for (bool useGpu : {false, true}) {
+    test("gserver/tests/sequence_rnn_multi_input.conf",
+         "gserver/tests/sequence_nest_rnn_multi_input.conf",
+         1e-6, useGpu);
+  }
+}

 int main(int argc, char** argv) {
   if (paddle::version::isWithPyDataProvider()) {

paddle/parameter/Argument.cpp

Lines changed: 9 additions & 4 deletions
@@ -554,11 +554,16 @@ void Argument::degradeSequence(const Argument& input, bool useGpu) {
 void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
                           size_t width, bool useGpu, bool trans, bool seqFlag,
                           size_t seqStart, size_t seqSize) {
-  value = Matrix::create(input.value->getData() + offset, height, width, trans,
-                         useGpu);
+  if (input.value) {
+    value = Matrix::create(input.value->getData() + offset * width,
+                           height, width, trans, useGpu);
+  }
+  if (input.ids) {
+    ids = IVector::create(input.ids->getData() + offset, height, useGpu);
+  }
   if (input.grad) {
-    grad = Matrix::create(input.grad->getData() + offset, height, width, trans,
-                          useGpu);
+    grad = Matrix::create(input.grad->getData() + offset * width,
+                          height, width, trans, useGpu);
   }
   if (seqFlag) {
     sequenceStartPositions = std::make_shared<ICpuGpuVector>(

paddle/parameter/Argument.h

Lines changed: 2 additions & 2 deletions
@@ -177,11 +177,11 @@ struct Argument {
   }

   /**
-   * @brief (value, grad, sequenceStartPositions) of output are subset of
+   * @brief (value, ids, grad, sequenceStartPositions) of output are subset of
    * input. Note that, output share the same memory of input.
    *
    * @param input[in] input
-   * @param offset[in] offset of input.value
+   * @param offset[in] offset in terms of rows
    * @param height[in] height of output.value
    * @param width[in] width of output.value
    * @param useGpu[in]
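
A worked example of the new row-based offset convention (numbers illustrative, not from the commit): with width = 8, offset = 3 selects row 3, so input.value's data pointer advances by 3 * 8 = 24 elements, while input.ids, holding one element per row, advances by just 3. Callers now pass a row index directly, which is why the ScatterAgentLayer call sites above change /* offset */ idIndex_ * width to /* offset */ idIndex_.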

python/paddle/trainer_config_helpers/layers.py

Lines changed: 2 additions & 2 deletions
@@ -216,7 +216,7 @@ def check_input(input):
     """

     if isinstance(input, LayerOutput):
-        return [LayerOutput]
+        return [input]
     assert isinstance(input, list)
     for inp in input:
         assert isinstance(inp, LayerOutput)

@@ -764,7 +764,7 @@ def print_layer(input, name=None):
     :type input: LayerOutput|list|tuple
     :return: No return
     """
-    check_input(input)
+    input = check_input(input)

     Layer(
         name=name,
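
A minimal standalone sketch of the check_input fix (the LayerOutput stub and the final return are assumptions for illustration, not the real helpers):

class LayerOutput(object):  # stub standing in for the real class
    pass

def check_input(input):
    # Normalize a single LayerOutput into a one-element list.
    if isinstance(input, LayerOutput):
        return [input]  # old code returned [LayerOutput]: the class object itself
    assert isinstance(input, list)
    for inp in input:
        assert isinstance(inp, LayerOutput)
    return input  # assumed: the list branch returns the validated list

layer = LayerOutput()
assert check_input(layer) == [layer]    # the instance, not the class
assert check_input([layer]) == [layer]

With print_layer rebinding input to the normalized list, a bare LayerOutput argument now behaves the same as a one-element list.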
