Skip to content

Commit 7484915

Browse files
reyoung and JiayiFeng
authored and committed
Add LoDRankTable (#5349)
* Add LoDRankTable. LoD Rank Table stores the `level` of `lod` which is ordered by sequence length in descending order. It is useful when implementing dynamic RNN and is shared by the dynamic RNN memory, dynamic RNN slice input, and dynamic RNN slice output operators. * Add InferVarType
1 parent 73632de commit 7484915

File tree

13 files changed

+249
-3
lines changed

13 files changed

+249
-3
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
4545

4646
cc_library(backward SRCS backward.cc DEPS net_op)
4747
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
48+
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
4849

49-
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
50+
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table)
5051

5152
cc_library(prune SRCS prune.cc DEPS framework_proto)
5253
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)

paddle/framework/executor.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include <vector>
2222

2323
#include "paddle/framework/feed_fetch_type.h"
24+
#include "paddle/framework/lod_rank_table.h"
2425
#include "paddle/framework/lod_tensor.h"
2526
#include "paddle/framework/op_registry.h"
2627
#include "paddle/framework/scope.h"
@@ -70,10 +71,12 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
7071
var->GetMutable<FeedFetchList>();
7172
} else if (var_type == VarDesc::STEP_SCOPES) {
7273
var->GetMutable<std::vector<framework::Scope>>();
74+
} else if (var_type == VarDesc::LOD_RANK_TABLE) {
75+
var->GetMutable<LoDRankTable>();
7376
} else {
7477
PADDLE_THROW(
7578
"Variable type %d is not in "
76-
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST]",
79+
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE]",
7780
var_type);
7881
}
7982
}

paddle/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ message VarDesc {
116116
FEED_MINIBATCH = 3;
117117
FETCH_LIST = 4;
118118
STEP_SCOPES = 5;
119+
LOD_RANK_TABLE = 6;
119120
}
120121
required string name = 1;
121122
required VarType type = 2;

paddle/framework/lod_rank_table.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/lod_rank_table.h"

#include <algorithm>  // std::sort was used below without this include (IWYU).

namespace paddle {
namespace framework {

// Rebuilds the rank table from the `level`-th level of `lod`.
//
// Every level coarser than `level` is copied verbatim into coarse_lod_ so
// that the output LoD information can be restored later. Each sequence in
// lod[level] becomes one TableItem {index, length}; the items are then
// sorted by length in descending order. std::sort is not stable, so the
// relative order of equal-length sequences is unspecified.
void LoDRankTable::Reset(const LoD& lod, size_t level) {
  this->coarse_lod_.clear();
  this->items_.clear();
  // The original message described the passing case ("level is less than
  // lod size") — this enforce fires in exactly the opposite situation.
  PADDLE_ENFORCE(level < lod.size(),
                 "Cannot rank lod since the level %d is not less than lod "
                 "size %d",
                 level, lod.size());
  coarse_lod_.reserve(level);
  for (size_t i = 0; i < level; ++i) {
    coarse_lod_.push_back(lod[i]);
  }
  auto& vec = lod[level];
  // Guard the empty level: vec.size() is unsigned, so vec.size() - 1 would
  // wrap around to SIZE_MAX and the loop below would run wild.
  if (!vec.empty()) {
    for (size_t i = 0; i < vec.size() - 1; ++i) {
      TableItem item;
      item.index = i;
      item.length = vec[i + 1] - vec[i];
      items_.emplace_back(item);
    }
  }
  std::sort(items_.begin(), items_.end(),
            [](const TableItem& a, const TableItem& b) {
              return a.length > b.length;
            });
}

}  // namespace framework
}  // namespace paddle

paddle/framework/lod_rank_table.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
#include "paddle/framework/lod_tensor.h"

namespace paddle {
namespace framework {

// A rank table built from one level of a LoD descriptor.
//
// For every sequence in the chosen level, the table records the sequence's
// index in that level together with its length, ordered from longest to
// shortest. Dynamic RNN operators (memory, slice-input and slice-output)
// share this structure when reordering sequences, which is why it lives in
// the framework rather than in any single operator.
//
// Levels coarser than the chosen one are kept in coarse_lod_ so that the
// LoD information of the output can be reconstructed afterwards.
class LoDRankTable {
 public:
  // One entry per sequence: its position inside the ranked level and the
  // number of elements it spans.
  struct TableItem {
    size_t index;
    size_t length;
  };

  LoDRankTable() {}

  // Rebuild the table from lod[level]; implemented in lod_rank_table.cc.
  void Reset(const LoD& lod, size_t level);

  // Items sorted by descending sequence length.
  const std::vector<TableItem>& items() const { return items_; }

  // LoD levels coarser than the ranked one, preserved for restoring the
  // output LoD.
  const LoD& coarse_lod() const { return coarse_lod_; }

  // The ranked level, i.e. the number of coarse levels that were skipped.
  size_t level() const { return coarse_lod_.size(); }

 private:
  LoD coarse_lod_;
  std::vector<TableItem> items_;
};

}  // namespace framework
}  // namespace paddle

paddle/framework/var_desc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <vector>
18+
#include "glog/logging.h"
1819
#include "paddle/framework/framework.pb.h"
1920

2021
namespace paddle {

paddle/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ set(DEPS_OPS
141141
pool_with_index_op
142142
nccl_op
143143
sequence_conv_op
144+
lod_rank_table_op
144145
lstm_op)
145146

146147
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
@@ -149,6 +150,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
149150
op_library(sum_op DEPS net_op selected_rows_functor)
150151
op_library(pool_op DEPS pooling)
151152
op_library(pool_with_index_op DEPS pooling)
153+
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
152154
if(WITH_GPU)
153155
op_library(nccl_op DEPS nccl_common)
154156
endif()

paddle/operators/lod_rank_table_op.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {

// Runtime operator: reads the LoDTensor input `X` and fills the
// LoDRankTable variable `Out` from the lod level selected by the `level`
// attribute (see LoDRankTable::Reset).
class LoDRankTableOp : public framework::OperatorBase {
 public:
  LoDRankTableOp(const std::string &type,
                 const framework::VariableNameMap &inputs,
                 const framework::VariableNameMap &outputs,
                 const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
    auto *out =
        scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
    out->Reset(x.lod(), static_cast<size_t>(Attr<int>("level")));
  }
};

class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LoDRankTableOpProtoMaker(framework::OpProto *proto,
                           framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(LoDTensor) input lod tensor, must contain lod information.");
    AddOutput("Out", "(LoDRankTable) The rank table of specific level.");
    AddAttr<int>("level", "(int) the specific lod level to rank.")
        .SetDefault(0)
        .EqualGreaterThan(0);
    // Fixed user-facing typos in the op documentation:
    // "LoDRanTable" -> "LoDRankTable", "when implement" -> "when
    // implementing".
    AddComment(R"DOC(Create LoDRankTable by LoDTensor

LoD Rank Table stores the `level` of `lod` which is ordered by sequence
length in descending order. It is useful when implementing dynamic RNN and is
shared by dynamic RNN memory, dynamic RNN slice input and dynamic RNN slice
output operators.
)DOC");
  }
};

// Compile-time check: only requires the input to be present. The output is
// a rank table rather than a tensor, so there are no dims to infer.
class LoDRankTableInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    // Fixed grammar in the enforce message: "must has" -> "must have".
    PADDLE_ENFORCE(context->HasInput("X"), "LoDRankTable must have input X");
  }
};

// Marks every output variable as LOD_RANK_TABLE so the executor creates a
// variable of the right type for `Out`.
class LoDRankTableInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDescBind &op_desc,
                  framework::BlockDescBind *block) const override {
    for (auto &o : op_desc.Output("Out")) {
      block->Var(o)->SetType(framework::VarDesc::LOD_RANK_TABLE);
    }
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(lod_rank_table, paddle::operators::LoDRankTableOp,
                  paddle::operators::LoDRankTableOpProtoMaker,
                  paddle::operators::LoDRankTableInferShape,
                  paddle::operators::LoDRankTableInferVarType,
                  paddle::framework::EmptyGradOpMaker);

paddle/pybind/protobuf.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ void BindVarDsec(py::module &m) {
238238
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
239239
.value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
240240
.value("FETCH_LIST", VarDesc::FETCH_LIST)
241-
.value("STEP_SCOPES", VarDesc::STEP_SCOPES);
241+
.value("STEP_SCOPES", VarDesc::STEP_SCOPES)
242+
.value("LOD_RANK_TABLE", VarDesc::LOD_RANK_TABLE);
242243
}
243244

244245
void BindOpDesc(py::module &m) {

paddle/pybind/pybind.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/framework/executor.h"
2222
#include "paddle/framework/feed_fetch_method.h"
2323
#include "paddle/framework/framework.pb.h"
24+
#include "paddle/framework/lod_rank_table.h"
2425
#include "paddle/framework/lod_tensor.h"
2526
#include "paddle/framework/prune.h"
2627
#include "paddle/framework/selected_rows.h"
@@ -224,6 +225,9 @@ All parameter, weight, gradient are variables in Paddle.
224225
return self.GetMutable<LoDTensor>();
225226
},
226227
py::return_value_policy::reference)
228+
.def("get_lod_rank_table",
229+
[](Variable &self) { return self.GetMutable<LoDRankTable>(); },
230+
py::return_value_policy::reference)
227231
.def("get_selected_rows",
228232
[](Variable &self) -> SelectedRows * {
229233
return self.GetMutable<SelectedRows>();
@@ -492,6 +496,15 @@ All parameter, weight, gradient are variables in Paddle.
492496
BindVarDsec(m);
493497
BindOpDesc(m);
494498

499+
py::class_<framework::LoDRankTable>(m, "LodRankTable")
500+
.def("items", [](framework::LoDRankTable &table) {
501+
std::vector<std::pair<size_t, size_t>> res;
502+
for (auto &item : table.items()) {
503+
res.push_back({item.index, item.length});
504+
}
505+
return res;
506+
});
507+
495508
m.def("op_support_gpu", OpSupportGPU);
496509
#ifdef PADDLE_WITH_CUDA
497510
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);

0 commit comments

Comments
 (0)