Skip to content

Commit 39e129a

Browse files
authored
Merge pull request #13794 from chengduoZH/set_right_shape_of_seleceted_rows_release
Set the output shape of selected rows in the right way.
2 parents b2e6e5f + e7e69e2 commit 39e129a

17 files changed

+162
-29
lines changed

paddle/fluid/framework/op_desc.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
5050
const std::vector<std::string> &Outputs(
5151
const std::string &name) const override;
5252

53+
void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
54+
size_t j = 0) override {
55+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
56+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
57+
const std::string &input_n = Inputs(in)[i];
58+
const std::string &output_n = Outputs(out)[j];
59+
60+
PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
61+
in, i);
62+
PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
63+
"The %s[%d] is @EMPTY@", out, j);
64+
65+
auto *in_var = block_.FindVarRecursive(input_n);
66+
auto *out_var = block_.FindVarRecursive(output_n);
67+
68+
PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
69+
"The type of %s and %s is not the same.", input_n, output_n);
70+
71+
SetDim(output_n, GetDim(input_n));
72+
}
73+
5374
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
5475
size_t j = 0) const override {
5576
PADDLE_ENFORCE_LT(i, Inputs(in).size());

paddle/fluid/framework/operator.cc

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -542,13 +542,45 @@ class RuntimeInferShapeContext : public InferShapeContext {
542542
return op_.Outputs(name);
543543
}
544544

545-
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
546-
size_t j = 0) const override {
545+
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
546+
size_t j = 0) override {
547547
PADDLE_ENFORCE_LT(i, Inputs(in).size());
548548
PADDLE_ENFORCE_LT(j, Outputs(out).size());
549-
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
550-
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
549+
const std::string& input_n = Inputs(in)[i];
550+
const std::string& output_n = Outputs(out)[j];
551+
552+
Variable* in_var = scope_.FindVar(input_n);
553+
Variable* out_var = scope_.FindVar(output_n);
554+
PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
555+
"The type of %s and %s is not the same.", output_n,
556+
GetDim(input_n));
557+
558+
if (in_var->IsType<framework::SelectedRows>()) {
559+
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
560+
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
561+
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
562+
out_sele_rows->set_rows(in_sele_rows.rows());
563+
out_sele_rows->set_height(in_sele_rows.height());
564+
} else if (in_var->IsType<framework::LoDTensor>()) {
565+
auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
566+
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
567+
out_lod_tensor->Resize(in_lod_tensor.dims());
568+
} else {
569+
PADDLE_THROW(
570+
"Currently, the input type of ShareDim only can be LoDTensor "
571+
"or SelectedRows.");
572+
}
573+
}
574+
575+
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
576+
size_t j = 0) const override {
577+
const std::vector<std::string>& inputs = Inputs(in);
578+
const std::vector<std::string>& outputs = Outputs(out);
579+
PADDLE_ENFORCE_LT(i, inputs.size());
580+
PADDLE_ENFORCE_LT(j, outputs.size());
581+
Variable* in_var = scope_.FindVar(inputs.at(i));
551582
if (!in_var->IsType<LoDTensor>()) return;
583+
Variable* out_var = scope_.FindVar(outputs.at(j));
552584
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
553585
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
554586
auto in_tensor = in_var->Get<LoDTensor>();

paddle/fluid/framework/shape_inference.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ class InferShapeContext {
5858

5959
void ShareLoDs(const std::string &in, const std::string &out) const;
6060

61+
virtual void ShareDim(const std::string &in, const std::string &out,
62+
size_t i = 0, size_t j = 0) = 0;
63+
6164
virtual void ShareLoD(const std::string &in, const std::string &out,
6265
size_t i = 0, size_t j = 0) const = 0;
6366

paddle/fluid/operators/activation_op.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class ActivationOp : public framework::OperatorWithKernel {
7979
using framework::OperatorWithKernel::OperatorWithKernel;
8080

8181
void InferShape(framework::InferShapeContext* ctx) const override {
82-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
82+
ctx->ShareDim("X", /*->*/ "Out");
8383
ctx->ShareLoD("X", /*->*/ "Out");
8484
}
8585

@@ -90,12 +90,26 @@ class ActivationOp : public framework::OperatorWithKernel {
9090
}
9191
};
9292

93+
class ActivationOpInferVarType : public framework::VarTypeInference {
94+
public:
95+
void operator()(const framework::OpDesc& op_desc,
96+
framework::BlockDesc* block) const override {
97+
auto x_name = op_desc.Input("X")[0];
98+
auto out_name = op_desc.Output("Out")[0];
99+
auto& x = block->FindRecursiveOrCreateVar(x_name);
100+
auto& out = block->FindRecursiveOrCreateVar(out_name);
101+
out.SetType(x.GetType());
102+
out.SetDataType(x.GetDataType());
103+
}
104+
};
105+
93106
class ActivationOpGrad : public framework::OperatorWithKernel {
94107
public:
95108
using framework::OperatorWithKernel::OperatorWithKernel;
96109

97110
void InferShape(framework::InferShapeContext* ctx) const override {
98-
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
111+
ctx->ShareDim("Out", framework::GradVarName("X"));
112+
ctx->ShareLoD("Out", framework::GradVarName("X"));
99113
}
100114

101115
protected:
@@ -524,12 +538,14 @@ namespace ops = paddle::operators;
524538
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
525539
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
526540
::paddle::operators::OP_NAME##OpMaker, \
541+
::paddle::operators::ActivationOpInferVarType, \
527542
::paddle::operators::OP_NAME##GradMaker); \
528543
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
529544

530545
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
531546
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
532547
::paddle::operators::OP_NAME##OpMaker, \
548+
::paddle::operators::ActivationOpInferVarType, \
533549
::paddle::framework::DefaultGradOpDescMaker<true>); \
534550
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
535551

paddle/fluid/operators/argsort_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
4242
"-rank(Input(X)) (%d).",
4343
axis, num_dims);
4444

45-
ctx->SetOutputDim("Out", in_dims);
46-
ctx->SetOutputDim("Indices", in_dims);
45+
ctx->ShareDim("X", "Out");
46+
ctx->ShareDim("X", "Indices");
4747
ctx->ShareLoD("X", "Out");
4848
ctx->ShareLoD("X", "Indices");
4949
}

paddle/fluid/operators/conv_shift_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
4444
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
4545
"The 2nd dimension of Input(Y) should be less than or "
4646
"equal to the 2nd dimension of Input(X).");
47-
ctx->SetOutputDim("Out", x_dims);
47+
ctx->ShareDim("X", /*->*/ "Out");
4848
ctx->ShareLoD("X", /*->*/ "Out");
4949
}
5050
};

paddle/fluid/operators/elementwise_op.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
4141
auto y_dim = ctx->GetInputDim("Y");
4242
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
4343
"Rank of first input must >= rank of second input.");
44-
ctx->SetOutputDim("Out", x_dim);
44+
45+
ctx->ShareDim("X", /*->*/ "Out");
4546
ctx->ShareLoD("X", /*->*/ "Out");
4647
}
4748

@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
7071
auto& x = block->FindRecursiveOrCreateVar(x_name);
7172
auto& out = block->FindRecursiveOrCreateVar(out_name);
7273
out.SetType(x.GetType());
74+
out.SetDataType(x.GetDataType());
7375
}
7476
};
7577

@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
157159
auto x_grad_name = framework::GradVarName("X");
158160
auto y_grad_name = framework::GradVarName("Y");
159161
if (ctx->HasOutput(x_grad_name)) {
160-
ctx->SetOutputDim(x_grad_name, x_dims);
162+
ctx->ShareDim("X", /*->*/ x_grad_name);
163+
ctx->ShareLoD("X", /*->*/ x_grad_name);
161164
}
162165
if (ctx->HasOutput(y_grad_name)) {
163-
ctx->SetOutputDim(y_grad_name, y_dims);
166+
ctx->ShareDim("Y", /*->*/ y_grad_name);
167+
ctx->ShareLoD("Y", /*->*/ y_grad_name);
164168
}
165169
}
166170

@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
193197

194198
auto x_grad_name = framework::GradVarName("X");
195199
if (ctx->HasOutput(x_grad_name)) {
196-
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
197-
ctx->SetOutputDim(x_grad_name, out_dims);
200+
ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
201+
ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
198202
}
199203
auto y_grad_name = framework::GradVarName("Y");
200204
if (ctx->HasOutput(y_grad_name)) {
201205
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
202-
auto y_dims = ctx->GetInputDim("Y");
203-
ctx->SetOutputDim(y_grad_name, y_dims);
206+
207+
ctx->ShareDim("Y", /*->*/ y_grad_name);
208+
ctx->ShareLoD("Y", /*->*/ y_grad_name);
204209
}
205210
}
206211
};

paddle/fluid/operators/fake_dequantize_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
4848
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
4949
PADDLE_ENFORCE(ctx->HasOutput("Out"),
5050
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
51-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
51+
52+
ctx->ShareDim("X", /*->*/ "Out");
5253
ctx->ShareLoD("X", /*->*/ "Out");
5354
}
5455
};

paddle/fluid/operators/prelu_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel {
4949
} else {
5050
PADDLE_THROW("Unkown mode %s", mode);
5151
}
52-
ctx->SetOutputDim("Out", x_dim);
52+
ctx->ShareDim("X", /*->*/ "Out");
5353
ctx->ShareLoD("X", /*->*/ "Out");
5454
}
5555

paddle/fluid/operators/rnn_memory_helper_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
5454
"Input(X) of rnn_memory_helper op should not be null.");
5555
PADDLE_ENFORCE(ctx->HasOutput("Out"),
5656
"Output of rnn_memory_helper op should not be null.");
57-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
57+
ctx->ShareDim("X", /*->*/ "Out");
5858
ctx->ShareLoD("X", /*->*/ "Out");
5959
}
6060
};

0 commit comments

Comments
 (0)