Skip to content

Commit 6cde889

Browse files
authored
Add unittest, backward of array read/write op (#5409)
* Use stable_sort in lod_rank_table. It is easier to debug and test when using `stable_sort`, and the time complexity is not changed. * Add LoDTensorArray * Stash * Better debug message for IsInitialized * Stash * Better debug message for IsInitialized * Complete array read/write op unittests * Add unittest, Gradient of array read/write * Follow comments
1 parent b25804c commit 6cde889

File tree

11 files changed

+210
-24
lines changed

11 files changed

+210
-24
lines changed

paddle/framework/op_desc.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
6767
out);
6868
in_var->SetLoDLevel(out_var->GetLodLevel());
6969
}
70+
bool IsRuntime() const override;
71+
72+
protected:
73+
VarDesc::VarType GetVarType(const std::string &name) const override;
7074

71-
private:
7275
DDim GetDim(const std::string &name) const override;
7376

7477
void SetDim(const std::string &name, const DDim &dim) override;
@@ -451,6 +454,12 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
451454
const DDim &dim) {
452455
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
453456
}
457+
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
458+
459+
VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
460+
const std::string &name) const {
461+
return block_.FindVarRecursive(name)->GetType();
462+
}
454463

455464
} // namespace framework
456465
} // namespace paddle

paddle/framework/operator.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ limitations under the License. */
1515
#include "paddle/framework/operator.h"
1616
#include <algorithm>
1717
#include <atomic>
18+
#include "paddle/framework/lod_tensor_array.h"
1819
#include "paddle/framework/shape_inference.h"
20+
#include "paddle/framework/var_type.h"
1921

2022
namespace paddle {
2123
namespace framework {
@@ -365,7 +367,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
365367
out_tensor->set_lod(in_tensor.lod());
366368
}
367369

368-
private:
370+
bool IsRuntime() const override { return true; }
371+
372+
protected:
369373
DDim GetDim(const std::string& name) const override {
370374
Variable* var = scope_.FindVar(name);
371375
if (var->IsType<LoDTensor>()) {
@@ -388,6 +392,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
388392
}
389393
}
390394

395+
VarDesc::VarType GetVarType(const std::string& name) const override {
396+
auto* var = scope_.FindVar(name);
397+
return ToVarType(var->Type());
398+
}
399+
400+
private:
391401
const OperatorBase& op_;
392402
const Scope& scope_;
393403
};

paddle/framework/shape_inference.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,23 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
4646
SetDim(names[i], dims[i]);
4747
}
4848
}
49+
std::vector<VarDesc::VarType> InferShapeContext::GetInputsVarType(
50+
const std::string &name) const {
51+
return GetVarTypes(Inputs(name));
52+
}
53+
std::vector<VarDesc::VarType> InferShapeContext::GetOutputsVarType(
54+
const std::string &name) const {
55+
return GetVarTypes(Outputs(name));
56+
}
57+
std::vector<VarDesc::VarType> InferShapeContext::GetVarTypes(
58+
const std::vector<std::string> &names) const {
59+
std::vector<VarDesc::VarType> retv;
60+
retv.resize(names.size());
61+
std::transform(names.begin(), names.end(), retv.begin(),
62+
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
63+
std::placeholders::_1));
64+
return retv;
65+
}
4966

5067
} // namespace framework
5168
} // namespace paddle

paddle/framework/shape_inference.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include "paddle/framework/attribute.h"
1818
#include "paddle/framework/ddim.h"
19+
#include "paddle/framework/framework.pb.h"
1920

2021
namespace paddle {
2122
namespace framework {
@@ -26,6 +27,10 @@ class InferShapeContext {
2627
virtual bool HasInput(const std::string &name) const = 0;
2728
virtual bool HasOutput(const std::string &name) const = 0;
2829

30+
std::vector<VarDesc::VarType> GetInputsVarType(const std::string &name) const;
31+
std::vector<VarDesc::VarType> GetOutputsVarType(
32+
const std::string &name) const;
33+
2934
virtual bool HasInputs(const std::string &name) const = 0;
3035
virtual bool HasOutputs(const std::string &name) const = 0;
3136

@@ -46,6 +51,8 @@ class InferShapeContext {
4651
virtual void ShareLoD(const std::string &in, const std::string &out,
4752
size_t i = 0, size_t j = 0) const = 0;
4853

54+
virtual bool IsRuntime() const = 0;
55+
4956
protected:
5057
virtual framework::DDim GetDim(const std::string &name) const = 0;
5158
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
@@ -55,6 +62,11 @@ class InferShapeContext {
5562

5663
void SetDims(const std::vector<std::string> &names,
5764
const std::vector<framework::DDim> &dims);
65+
66+
std::vector<VarDesc::VarType> GetVarTypes(
67+
const std::vector<std::string> &names) const;
68+
69+
virtual VarDesc::VarType GetVarType(const std::string &name) const = 0;
5870
};
5971

6072
} // namespace framework

paddle/framework/var_type.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/framework.pb.h"
17+
#include "paddle/framework/lod_rank_table.h"
18+
#include "paddle/framework/lod_tensor.h"
19+
#include "paddle/framework/lod_tensor_array.h"
20+
21+
namespace paddle {
22+
namespace framework {
23+
inline VarDesc::VarType ToVarType(std::type_index type) {
24+
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
25+
return VarDesc_VarType_LOD_TENSOR;
26+
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
27+
return VarDesc_VarType_LOD_RANK_TABLE;
28+
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
29+
return VarDesc_VarType_LOD_TENSOR_ARRAY;
30+
} else {
31+
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
32+
}
33+
}
34+
35+
} // namespace framework
36+
} // namespace paddle

paddle/framework/variable.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ class Variable {
4848

4949
void Clear() { holder_.reset(); }
5050

51+
std::type_index Type() const {
52+
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
53+
return holder_->Type();
54+
}
55+
5156
private:
5257
struct Placeholder {
5358
virtual ~Placeholder() {}

paddle/operators/sum_op.cc

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,16 @@ class SumOp : public framework::OperatorWithKernel {
2424

2525
void InferShape(framework::InferShapeContext* ctx) const override {
2626
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
27-
auto x_dims = ctx->GetInputsDim("X");
27+
2828
PADDLE_ENFORCE(ctx->HasOutput("Out"),
2929
"Output(Out) of SumOp should not be null.");
30+
if (ctx->IsRuntime() &&
31+
ctx->GetOutputsVarType("Out")[0] ==
32+
framework::VarDesc::LOD_TENSOR_ARRAY) {
33+
return; // skip runtime infershape when is tensor array;
34+
}
3035

36+
auto x_dims = ctx->GetInputsDim("X");
3137
size_t N = x_dims.size();
3238
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
3339

@@ -39,6 +45,28 @@ class SumOp : public framework::OperatorWithKernel {
3945
ctx->SetOutputDim("Out", in_dim);
4046
ctx->ShareLoD("X", /*->*/ "Out");
4147
}
48+
49+
protected:
50+
framework::DataType IndicateDataType(
51+
const framework::ExecutionContext& ctx) const override {
52+
auto x_vars = ctx.MultiInputVar("X");
53+
if (x_vars[0]->IsType<framework::LoDTensor>()) {
54+
return framework::ToDataType(
55+
x_vars[0]->Get<framework::LoDTensor>().type());
56+
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
57+
return framework::ToDataType(
58+
x_vars[0]->Get<framework::SelectedRows>().value().type());
59+
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
60+
auto& array = x_vars[0]->Get<framework::LoDTensorArray>();
61+
for (auto& each : array) {
62+
if (each.numel() != 0) {
63+
return framework::ToDataType(each.type());
64+
}
65+
}
66+
}
67+
PADDLE_THROW("Unexpected branch. Input type is %s",
68+
x_vars[0]->Type().name());
69+
}
4270
};
4371

4472
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -63,18 +91,32 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
6391
void operator()(const framework::OpDescBind& op_desc,
6492
framework::BlockDescBind* block) const override {
6593
auto& inputs = op_desc.Input("X");
66-
auto default_var_type = framework::VarDesc::SELECTED_ROWS;
94+
auto var_type = framework::VarDesc::SELECTED_ROWS;
6795

6896
bool any_input_is_lod_tensor = std::any_of(
6997
inputs.begin(), inputs.end(), [block](const std::string& name) {
7098
return block->Var(name)->GetType() == framework::VarDesc::LOD_TENSOR;
7199
});
72-
if (any_input_is_lod_tensor) {
73-
default_var_type = framework::VarDesc::LOD_TENSOR;
100+
101+
auto is_tensor_array = [block](const std::string& name) {
102+
return block->Var(name)->GetType() ==
103+
framework::VarDesc::LOD_TENSOR_ARRAY;
104+
};
105+
106+
bool any_input_is_tensor_array =
107+
std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
108+
bool all_inputs_are_tensor_array =
109+
std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
110+
111+
if (any_input_is_tensor_array) {
112+
PADDLE_ENFORCE(all_inputs_are_tensor_array);
113+
var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
114+
} else if (any_input_is_lod_tensor) {
115+
var_type = framework::VarDesc::LOD_TENSOR;
74116
}
75117

76118
auto out_var_name = op_desc.Output("Out").front();
77-
block->Var(out_var_name)->SetType(default_var_type);
119+
block->Var(out_var_name)->SetType(var_type);
78120
}
79121
};
80122

paddle/operators/sum_op.h

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ limitations under the License. */
1111

1212
#pragma once
1313
#include "paddle/framework/eigen.h"
14+
#include "paddle/framework/lod_tensor_array.h"
1415
#include "paddle/framework/op_registry.h"
1516
#include "paddle/operators/math/math_function.h"
1617
#include "paddle/operators/math/selected_rows_functor.h"
@@ -28,15 +29,15 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
2829
template <typename Place, typename T>
2930
class SumKernel : public framework::OpKernel<T> {
3031
public:
31-
void Compute(const framework::ExecutionContext& context) const override {
32+
void Compute(const framework::ExecutionContext &context) const override {
3233
auto in_vars = context.MultiInputVar("X");
3334
int N = in_vars.size();
3435
auto out_var = context.OutputVar("Out");
3536

3637
bool in_place = out_var == in_vars[0];
3738

3839
if (out_var->IsType<framework::LoDTensor>()) {
39-
auto* out = context.Output<Tensor>("Out");
40+
auto *out = context.Output<Tensor>("Out");
4041
out->mutable_data<T>(context.GetPlace());
4142

4243
auto result = EigenVector<T>::Flatten(*out);
@@ -51,20 +52,20 @@ class SumKernel : public framework::OpKernel<T> {
5152
// If in_place, just skip the first tensor
5253
for (int i = in_place ? 1 : 0; i < N; i++) {
5354
if (in_vars[i]->IsType<framework::LoDTensor>()) {
54-
auto& in_t = in_vars[i]->Get<framework::LoDTensor>();
55+
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
5556
auto in = EigenVector<T>::Flatten(in_t);
5657
result.device(place) = result + in;
5758
} else if (in_vars[i]->IsType<framework::SelectedRows>()) {
58-
auto& in_t = in_vars[i]->Get<framework::SelectedRows>();
59+
auto &in_t = in_vars[i]->Get<framework::SelectedRows>();
5960
functor(context.device_context(), in_t, out);
6061
} else {
6162
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
6263
}
6364
}
6465
} else if (out_var->IsType<framework::SelectedRows>()) {
6566
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
66-
auto* out = context.Output<SelectedRows>("Out");
67-
auto* out_value = out->mutable_value();
67+
auto *out = context.Output<SelectedRows>("Out");
68+
auto *out_value = out->mutable_value();
6869

6970
// Runtime InferShape
7071
size_t first_dim = 0;
@@ -88,9 +89,36 @@ class SumKernel : public framework::OpKernel<T> {
8889
offset, out);
8990
offset += in_vars[i]->Get<SelectedRows>().value().numel();
9091
}
92+
} else if (out_var->IsType<framework::LoDTensorArray>()) {
93+
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
94+
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
95+
PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
96+
"Only support all inputs are TensorArray");
97+
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();
98+
99+
for (size_t i = 0; i < in_array.size(); ++i) {
100+
if (in_array[i].numel() != 0) {
101+
if (i >= out_array.size()) {
102+
out_array.resize(i + 1);
103+
}
104+
if (out_array[i].numel() == 0) {
105+
out_array[i].CopyFrom(in_array[i], in_array[i].place(),
106+
context.device_context());
107+
out_array[i].set_lod(in_array[i].lod());
108+
} else {
109+
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
110+
auto in = EigenVector<T>::Flatten(in_array[i]);
111+
auto result = EigenVector<T>::Flatten(out_array[i]);
112+
result.device(context.GetEigenDevice<Place>()) = result + in;
113+
}
114+
}
115+
}
116+
}
117+
} else {
118+
PADDLE_THROW("Unexpected branch, output variable type is %s",
119+
out_var->Type().name());
91120
}
92121
}
93122
};
94-
95123
} // namespace operators
96124
} // namespace paddle

paddle/operators/tensor_array_read_write_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
115115
public:
116116
void operator()(const framework::OpDescBind &op_desc,
117117
framework::BlockDescBind *block) const override {
118-
VLOG(10) << "I am here?";
119118
for (auto &out_var : op_desc.OutputArgumentNames()) {
120119
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
121120
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);

python/paddle/v2/framework/layers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -801,12 +801,13 @@ def zeros(shape, dtype, main_program=None):
801801

802802
def increment(x, value=1.0, main_program=None):
803803
helper = LayerHelper("increment", **locals())
804+
tmp = helper.create_tmp_variable(dtype=x.data_type)
804805
helper.append_op(
805806
type='increment',
806807
inputs={'X': [x]},
807-
outputs={'Out': [x]},
808+
outputs={'Out': [tmp]},
808809
attrs={'step': value})
809-
return x
810+
return tmp
810811

811812

812813
def array_write(x, i, array=None, main_program=None):

0 commit comments

Comments
 (0)