Commit c88f98c

Merge pull request #5501 from reyoung/feature/lod_array_length
Add `lod_array_length` operator
2 parents 03fa1ed + d24d8c2 commit c88f98c

3 files changed: +101 −0 lines changed
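In brief, the new `lod_array_length` operator reads a LoDTensorArray `X` and writes its length (`Out = len(X)`) into a one-element int64 tensor that always lives on CPU, so the value can serve as a control variable. The `array_length` layer added to `layers.py` below wraps this operator. A minimal usage sketch, assembled from the layer and test code in this diff (the unit test at the end of the diff performs the same check):

import numpy
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor

# Write a tensor at index 10; the array then holds 11 entries (indices 0..10).
x = layers.zeros(shape=[10], dtype='int32')
i = layers.fill_constant(shape=[1], dtype='int64', value=10)
arr = layers.array_write(x, i=i)
arr_len = layers.array_length(arr)   # appends a `lod_array_length` op

# The length comes back as a one-element int64 CPU tensor.
exe = Executor(core.CPUPlace())
length = numpy.array(exe.run(fetch_list=[arr_len])[0])
print(length[0])  # 11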
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

class LoDArrayLengthOp : public framework::OperatorBase {
 public:
  LoDArrayLengthOp(const std::string &type,
                   const framework::VariableNameMap &inputs,
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    auto &out =
        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
    out.Resize({1});
    auto cpu = platform::CPUPlace();
    *out.mutable_data<int64_t>(cpu) = static_cast<int64_t>(x.size());
  }
};

class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LoDArrayLengthProtoMaker(framework::OpProto *proto,
                           framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(LoDTensorArray) The input tensor array.");
    AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t");
    AddComment(R"DOC(Get the length of lod tensor array

Out = len(X)

NOTE: The output is a CPU Tensor since the control variable should be only in
CPU and the length of LoDTensorArray should be used as control variables.
)DOC");
  }
};

class LoDArrayLengthInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasOutput("Out"));
    context->SetOutputDim("Out", {1});
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lod_array_length, ops::LoDArrayLengthOp,
                  ops::LoDArrayLengthInferShape, ops::LoDArrayLengthProtoMaker,
                  paddle::framework::EmptyGradOpMaker);

python/paddle/v2/framework/layers.py

Lines changed: 9 additions & 0 deletions
@@ -947,3 +947,12 @@ def shrink_memory(x, i, table, main_program=None):
        outputs={'Out': [out]},
        attrs={})
    return out


def array_length(array, main_program=None):
    helper = LayerHelper('array_length', **locals())
    tmp = helper.create_tmp_variable(dtype='int64')
    tmp.stop_gradient = True
    helper.append_op(
        type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]})
    return tmp
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import unittest
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor
import paddle.v2.framework.core as core
import numpy


class TestLoDArrayLength(unittest.TestCase):
    def test_array_length(self):
        tmp = layers.zeros(shape=[10], dtype='int32')
        i = layers.fill_constant(shape=[1], dtype='int64', value=10)
        arr = layers.array_write(tmp, i=i)
        arr_len = layers.array_length(arr)
        cpu = core.CPUPlace()
        exe = Executor(cpu)
        result = numpy.array(exe.run(fetch_list=[arr_len])[0])
        self.assertEqual(11, result[0])


if __name__ == '__main__':
    unittest.main()
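The expected value of 11 follows from `array_write` growing the array just enough to hold index `i`: writing at index 10 leaves the LoDTensorArray with 11 entries, which `lod_array_length` then reports. A companion sketch for the boundary case (hypothetical, not part of this commit; it assumes the same i + 1 growth behavior that the assertion above relies on):

import numpy
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor

# Writing at index 0 should leave exactly one entry in the array.
tmp = layers.zeros(shape=[10], dtype='int32')
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
arr = layers.array_write(tmp, i=i)
arr_len = layers.array_length(arr)
exe = Executor(core.CPUPlace())
assert numpy.array(exe.run(fetch_list=[arr_len])[0])[0] == 1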
