Skip to content

Commit cdf3a4c

Browse files
author
chengduo
authored
Fix concat_op InferShape (#13513)
* add ShareLoDs * refine * add Is EmptyVarName * refine Sharedlod
1 parent 6537b17 commit cdf3a4c

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

paddle/fluid/framework/op_desc.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
5454
size_t j = 0) const override {
5555
PADDLE_ENFORCE_LT(i, Inputs(in).size());
5656
PADDLE_ENFORCE_LT(j, Outputs(out).size());
57+
PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
58+
"The %s[%d] is @EMPTY@", in, i);
59+
PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
60+
"The %s[%d] is @EMPTY@", out, j);
5761
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
5862
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
5963
if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
@@ -63,6 +67,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
6367
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
6468
"The %d-th output of Output(%s) must be LoDTensor.", j,
6569
out);
70+
6671
out_var->SetLoDLevel(in_var->GetLoDLevel());
6772
}
6873

paddle/fluid/framework/shape_inference.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
4646
return this->GetRepeatedDims(arg_names[0]);
4747
}
4848

49+
void InferShapeContext::ShareLoDs(const std::string &in,
50+
const std::string &out) const {
51+
PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
52+
"The number of arguments in %s and %s is not equal.", in,
53+
out);
54+
for (size_t i = 0; i < in.size(); ++i) {
55+
ShareLoD(in, out, i, i);
56+
}
57+
}
58+
4959
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
5060
int idx) const {
5161
const std::vector<std::string> &names = Inputs(name);

paddle/fluid/framework/shape_inference.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ class InferShapeContext {
5656
virtual const std::vector<std::string> &Outputs(
5757
const std::string &name) const = 0;
5858

59+
void ShareLoDs(const std::string &in, const std::string &out) const;
60+
5961
virtual void ShareLoD(const std::string &in, const std::string &out,
6062
size_t i = 0, size_t j = 0) const = 0;
6163

paddle/fluid/operators/concat_op.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
9494
: OperatorWithKernel(type, inputs, outputs, attrs) {}
9595

9696
void InferShape(framework::InferShapeContext *ctx) const override {
97-
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
98-
ctx->ShareLoD("X", framework::GradVarName("X"));
97+
auto in_x = "X";
98+
auto out_x_g_n = framework::GradVarName(in_x);
99+
ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
100+
auto &in_names = ctx->Inputs(in_x);
101+
auto &out_names = ctx->Outputs(out_x_g_n);
102+
PADDLE_ENFORCE_EQ(
103+
in_names.size(), out_names.size(),
104+
"The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
105+
in_names.size(), out_x_g_n, out_names.size());
106+
for (size_t i = 0; i < in_names.size(); ++i) {
107+
if (out_names[i] != framework::kEmptyVarName) {
108+
ctx->ShareLoD(in_x, out_x_g_n, i, i);
109+
}
110+
}
99111
}
100112
};
101113

0 commit comments

Comments
 (0)