Commit 6776e92

Author: chengduo

refine tensor_array_write_read (#14643)
test=develop
1 parent dfd4a11 commit 6776e92

8 files changed, 57 additions and 2 deletions


paddle/fluid/framework/op_desc.cc

Lines changed: 24 additions & 2 deletions
@@ -81,13 +81,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
                    "The %s[%d] is @EMPTY@", out, j);
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
-      VLOG(3) << "input " << in << " is not LodTensor";
+    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
+        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
+      VLOG(3) << "input " << in << " is not LodTensor or LodTensorArray.";
       return;
     }
     out_var->SetLoDLevel(in_var->GetLoDLevel());
   }

+  void DecreaseLoDLevel(const std::string &in, const std::string &out,
+                        size_t i = 0, size_t j = 0) const override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size());
+    PADDLE_ENFORCE_LT(j, Outputs(out).size());
+    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", in, i);
+    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
+                   "The %s[%d] is @EMPTY@", out, j);
+    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
+    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
+    PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
+                       out_var->GetType() == proto::VarType::LOD_TENSOR,
+                   "The input %s should be LodTensorArray or LodTensor.",
+                   out_var->Name());
+    PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
+                   "The input %s should be LodTensor.", in_var->Name());
+    if (in_var->GetLoDLevel() > 0) {
+      out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
+    }
+  }
+
   bool IsRuntime() const override;

  protected:
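
Note: the compile-time context above only manipulates the lod_level recorded on each VarDesc of the program; nothing is executed. A minimal Python sketch of where that metadata lives (illustrative names, assuming the Fluid 1.x Program / layers.data API):

import paddle.fluid as fluid

# Compile-time only: lod_level is VarDesc metadata, which is what
# CompileTimeInferShapeContext::ShareLoD / DecreaseLoDLevel read and write.
prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=2)
    print(x.lod_level)                       # 2, read from the VarDesc
    print(prog.block(0).var('x').lod_level)  # same descriptor, same value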

paddle/fluid/framework/operator.cc

Lines changed: 5 additions & 0 deletions
@@ -623,6 +623,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
     out_tensor->set_layout(in_tensor.layout());
   }

+  void DecreaseLoDLevel(const std::string& in, const std::string& out,
+                        size_t i = 0, size_t j = 0) const override {
+    PADDLE_THROW("DecreaseLoDLevel is only used in compile time.");
+  }
+
   bool IsRuntime() const override { return true; }

  protected:

paddle/fluid/framework/shape_inference.h

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ class InferShapeContext {
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;

+  virtual void DecreaseLoDLevel(const std::string &in, const std::string &out,
+                                size_t i = 0, size_t j = 0) const = 0;
+
   virtual bool IsRuntime() const = 0;

   std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);

paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc

Lines changed: 13 additions & 0 deletions
@@ -167,6 +167,19 @@ equation is
 };

 class ReadFromArrayInferShape : public WriteToArrayInferShape {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    WriteToArrayInferShape::operator()(context);
+    if (!context->HasInput("X")) {
+      return;
+    }
+
+    // FIXME: just for compile time.
+    if (!context->IsRuntime()) {
+      context->ShareLoD("X", /*->*/ "Out");
+    }
+  }
+
  protected:
   const char *NotHasXError() const override {
     return "The input array X must be set";

paddle/fluid/operators/lod_tensor_to_array_op.cc

Lines changed: 4 additions & 0 deletions
@@ -192,6 +192,10 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
     // The first dim of each LoDTensor in Output can only be set at run-time.;
     // We still have to Resize each LoDTensor in Output.
     context->SetOutputDim("Out", x_dim);
+    // The lod level should be passed to out in compile time.
+    if (!context->IsRuntime()) {
+      context->DecreaseLoDLevel("X", /*->*/ "Out");
+    }
   }
 };

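
With DecreaseLoDLevel wired into LoDTensorToArrayInferShape, the output array is known at compile time to hold tensors with one LoD level less than the input. A minimal sketch, assuming the Fluid 1.x lod_rank_table / lod_tensor_to_array wrappers:

import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    x = fluid.layers.data(name='x', shape=[10], dtype='float32', lod_level=2)
    table = fluid.layers.lod_rank_table(x, level=0)
    arr = fluid.layers.lod_tensor_to_array(x, table)
    # The program description now records x.lod_level - 1 for the array,
    # without running an executor.
    print(arr.lod_level)  # expected: 1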

paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc

Lines changed: 3 additions & 0 deletions
@@ -201,6 +201,9 @@ class IdentityInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *context) const override {
     context->SetOutputDim("Out", context->GetInputDim("X"));
+    if (!context->IsRuntime()) {
+      context->ShareLoD("X", /*->*/ "Out");
+    }
   }
 };


paddle/fluid/operators/shrink_rnn_memory_op.cc

Lines changed: 3 additions & 0 deletions
@@ -100,6 +100,9 @@ class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
     PADDLE_ENFORCE(context->HasInput("I"));
     PADDLE_ENFORCE(context->HasInput("RankTable"));
     context->SetOutputDim("Out", context->GetInputDim("X"));
+    if (!context->IsRuntime()) {
+      context->DecreaseLoDLevel("X", /*->*/ "Out");
+    }
   }
 };


python/paddle/fluid/tests/unittests/test_dyn_rnn.py

Lines changed: 2 additions & 0 deletions
@@ -172,13 +172,15 @@ def fake_reader():
             rnn = fluid.layers.DynamicRNN()
             with rnn.block():
                 in_ = rnn.step_input(sentence)
+                assert in_.lod_level == 1, "the lod level of in_ should be 1"
                 sent_emb = fluid.layers.embedding(
                     input=in_, size=[len(word_dict), 32], dtype='float32')
                 out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')

             rnn1 = fluid.layers.DynamicRNN()
             with rnn1.block():
                 in_1 = rnn1.step_input(out_)
+                assert in_1.lod_level == 0, "the lod level of in_1 should be 0"
                 out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
                 rnn1.output(out_1)

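
The new assertions rely only on program construction, not on running the network. A self-contained sketch of the same compile-time check (sizes and names are made up; it assumes the Fluid 1.x DynamicRNN API):

import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    # A sequence of sequences: two LoD levels.
    sentence = fluid.layers.data(
        name='sentence', shape=[1], dtype='int64', lod_level=2)
    assert sentence.lod_level == 2

    rnn = fluid.layers.DynamicRNN()
    with rnn.block():
        # step_input iterates over the outer sequence, so one LoD level is
        # stripped, and that is now visible before any executor runs.
        word = rnn.step_input(sentence)
        assert word.lod_level == 1
        emb = fluid.layers.embedding(
            input=word, size=[10000, 32], dtype='float32')
        rnn.output(fluid.layers.fc(input=emb, size=100, act='tanh'))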
