Skip to content

Commit e176170

Browse files
author
chengduo
authored
Set the right shape of selected_rows (#13723)
* Set the right shape of selected_rows (test=develop)
* Enhance check
* Fix activation_op
* Remove cast
* Use ShareDimInfo to replace SetDim and ShareLod
* Use ShareDimAndLod (test=develop)
* Follow comment (test=develop)
* Check whether the input has lod (test=develop)
* Split ShareDimAndLod (test=develop)
* Checkout clip.py (test=develop)
1 parent 2a36f0a commit e176170

19 files changed

+158
-25
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ third_party/
2525
bazel-*
2626
third_party/
2727

28+
build_*
2829
# clion workspace.
2930
cmake-build-*

paddle/fluid/framework/op_desc.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
5050
const std::vector<std::string> &Outputs(
5151
const std::string &name) const override;
5252

53+
void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
54+
size_t j = 0) override {
55+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
56+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
57+
const std::string &input_n = Inputs(in)[i];
58+
const std::string &output_n = Outputs(out)[j];
59+
60+
PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
61+
in, i);
62+
PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
63+
"The %s[%d] is @EMPTY@", out, j);
64+
65+
auto *in_var = block_.FindVarRecursive(input_n);
66+
auto *out_var = block_.FindVarRecursive(output_n);
67+
68+
PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
69+
"The type of %s and %s is not the same.", input_n, output_n);
70+
71+
SetDim(output_n, GetDim(input_n));
72+
}
73+
5374
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
5475
size_t j = 0) const override {
5576
PADDLE_ENFORCE_LT(i, Inputs(in).size());

paddle/fluid/framework/operator.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
542542
return op_.Outputs(name);
543543
}
544544

545+
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
546+
size_t j = 0) override {
547+
PADDLE_ENFORCE_LT(i, Inputs(in).size());
548+
PADDLE_ENFORCE_LT(j, Outputs(out).size());
549+
const std::string& input_n = Inputs(in)[i];
550+
const std::string& output_n = Outputs(out)[j];
551+
552+
Variable* in_var = scope_.FindVar(input_n);
553+
Variable* out_var = scope_.FindVar(output_n);
554+
PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
555+
"The type of %s and %s is not the same.", output_n,
556+
GetDim(input_n));
557+
558+
if (in_var->IsType<framework::SelectedRows>()) {
559+
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
560+
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
561+
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
562+
out_sele_rows->set_rows(in_sele_rows.rows());
563+
out_sele_rows->set_height(in_sele_rows.height());
564+
} else if (in_var->IsType<framework::LoDTensor>()) {
565+
auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
566+
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
567+
out_lod_tensor->Resize(in_lod_tensor.dims());
568+
} else {
569+
PADDLE_THROW(
570+
"Currently, the input type of ShareDim only can be LoDTensor "
571+
"or SelectedRows.");
572+
}
573+
}
574+
545575
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
546576
size_t j = 0) const override {
547577
const std::vector<std::string>& inputs = Inputs(in);

paddle/fluid/framework/shape_inference.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class InferShapeContext {
5656
virtual const std::vector<std::string> &Outputs(
5757
const std::string &name) const = 0;
5858

59+
virtual void ShareDim(const std::string &in, const std::string &out,
60+
size_t i = 0, size_t j = 0) = 0;
61+
5962
virtual void ShareLoD(const std::string &in, const std::string &out,
6063
size_t i = 0, size_t j = 0) const = 0;
6164

paddle/fluid/operators/activation_op.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel {
8080
using framework::OperatorWithKernel::OperatorWithKernel;
8181

8282
void InferShape(framework::InferShapeContext* ctx) const override {
83-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
83+
ctx->ShareDim("X", /*->*/ "Out");
8484
ctx->ShareLoD("X", /*->*/ "Out");
8585
}
8686

@@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel {
9191
}
9292
};
9393

94+
// Infers the variable type for activation ops: the output "Out" inherits
// both the var type and the data type of the input "X", so activations work
// transparently on LoDTensor as well as SelectedRows inputs.
class ActivationOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    // Activation ops have exactly one input/output slot each.
    auto& in_var = block->FindRecursiveOrCreateVar(op_desc.Input("X")[0]);
    auto& out_var = block->FindRecursiveOrCreateVar(op_desc.Output("Out")[0]);
    out_var.SetType(in_var.GetType());
    out_var.SetDataType(in_var.GetDataType());
  }
};
106+
94107
class ActivationOpGrad : public framework::OperatorWithKernel {
95108
public:
96109
using framework::OperatorWithKernel::OperatorWithKernel;
97110

98111
void InferShape(framework::InferShapeContext* ctx) const override {
99-
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
112+
ctx->ShareDim("Out", framework::GradVarName("X"));
113+
ctx->ShareLoD("Out", framework::GradVarName("X"));
100114
}
101115

102116
protected:
@@ -525,12 +539,14 @@ namespace ops = paddle::operators;
525539
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
526540
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
527541
::paddle::operators::OP_NAME##OpMaker, \
542+
::paddle::operators::ActivationOpInferVarType, \
528543
::paddle::operators::OP_NAME##GradMaker); \
529544
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
530545

531546
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
532547
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
533548
::paddle::operators::OP_NAME##OpMaker, \
549+
::paddle::operators::ActivationOpInferVarType, \
534550
::paddle::framework::DefaultGradOpDescMaker<true>); \
535551
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
536552

paddle/fluid/operators/argsort_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
4242
"-rank(Input(X)) (%d).",
4343
axis, num_dims);
4444

45-
ctx->SetOutputDim("Out", in_dims);
46-
ctx->SetOutputDim("Indices", in_dims);
45+
ctx->ShareDim("X", "Out");
46+
ctx->ShareDim("X", "Indices");
4747
ctx->ShareLoD("X", "Out");
4848
ctx->ShareLoD("X", "Indices");
4949
}

paddle/fluid/operators/conv_shift_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel {
4444
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
4545
"The 2nd dimension of Input(Y) should be less than or "
4646
"equal to the 2nd dimension of Input(X).");
47-
ctx->SetOutputDim("Out", x_dims);
47+
ctx->ShareDim("X", /*->*/ "Out");
4848
ctx->ShareLoD("X", /*->*/ "Out");
4949
}
5050
};

paddle/fluid/operators/elementwise_op.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
4141
auto y_dim = ctx->GetInputDim("Y");
4242
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
4343
"Rank of first input must >= rank of second input.");
44-
ctx->SetOutputDim("Out", x_dim);
44+
45+
ctx->ShareDim("X", /*->*/ "Out");
4546
ctx->ShareLoD("X", /*->*/ "Out");
4647
}
4748

@@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
7071
auto& x = block->FindRecursiveOrCreateVar(x_name);
7172
auto& out = block->FindRecursiveOrCreateVar(out_name);
7273
out.SetType(x.GetType());
74+
out.SetDataType(x.GetDataType());
7375
}
7476
};
7577

@@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
157159
auto x_grad_name = framework::GradVarName("X");
158160
auto y_grad_name = framework::GradVarName("Y");
159161
if (ctx->HasOutput(x_grad_name)) {
160-
ctx->SetOutputDim(x_grad_name, x_dims);
162+
ctx->ShareDim("X", /*->*/ x_grad_name);
163+
ctx->ShareLoD("X", /*->*/ x_grad_name);
161164
}
162165
if (ctx->HasOutput(y_grad_name)) {
163-
ctx->SetOutputDim(y_grad_name, y_dims);
166+
ctx->ShareDim("Y", /*->*/ y_grad_name);
167+
ctx->ShareLoD("Y", /*->*/ y_grad_name);
164168
}
165169
}
166170

@@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
193197

194198
auto x_grad_name = framework::GradVarName("X");
195199
if (ctx->HasOutput(x_grad_name)) {
196-
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
197-
ctx->SetOutputDim(x_grad_name, out_dims);
200+
ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
201+
ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
198202
}
199203
auto y_grad_name = framework::GradVarName("Y");
200204
if (ctx->HasOutput(y_grad_name)) {
201205
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
202-
auto y_dims = ctx->GetInputDim("Y");
203-
ctx->SetOutputDim(y_grad_name, y_dims);
206+
207+
ctx->ShareDim("Y", /*->*/ y_grad_name);
208+
ctx->ShareLoD("Y", /*->*/ y_grad_name);
204209
}
205210
}
206211
};

paddle/fluid/operators/fake_dequantize_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
4848
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
4949
PADDLE_ENFORCE(ctx->HasOutput("Out"),
5050
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
51-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
51+
52+
ctx->ShareDim("X", /*->*/ "Out");
5253
ctx->ShareLoD("X", /*->*/ "Out");
5354
}
5455
};

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
137137
<< " is set to LoDTensor";
138138
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
139139
}
140+
block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
140141
}
141142
};
142143

0 commit comments

Comments (0)