Commit 8015fbd

[Cherry-pick] Move sum op to PHI && fix MetaTensor's bug when running infermeta (#49342)

* cherry-pick 45860
* [BUG FIX] Fix MetaTensor's bug when running infermeta (#46265)
* Fix sum bug
* Fix CI bugs
* Fix CI bugs
* Update code according to review comments
1 parent 8aa5be9 commit 8015fbd


43 files changed (+810 / -686 lines)

paddle/fluid/framework/infershape_utils.cc

Lines changed: 21 additions & 3 deletions
@@ -87,6 +87,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
     });
   }
 
+  bool IsSelectedRowsInputs(const std::string& name) const override {
+    auto var_types = ctx_.GetInputsVarType(name);
+    return std::all_of(var_types.begin(),
+                       var_types.end(),
+                       [](const proto::VarType::Type& type) {
+                         return type == proto::VarType::SELECTED_ROWS;
+                       });
+  }
+
   bool IsSelectedRowsInput(const std::string& name) const override {
     auto var_type = ctx_.GetInputVarType(name);
     return var_type == proto::VarType::SELECTED_ROWS;
@@ -155,6 +164,16 @@ int64_t CompatMetaTensor::numel() const {
   }
 }
 
+bool CompatMetaTensor::is_selected_rows() const {
+  if (is_runtime_) {
+    auto* var = PADDLE_GET_CONST(Variable*, var_);
+    return var->IsType<phi::SelectedRows>();
+  } else {
+    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
+    return var->GetType() == proto::VarType::SELECTED_ROWS;
+  }
+}
+
 bool CompatMetaTensor::is_dense() const {
   if (is_runtime_) {
     auto* var = PADDLE_GET_CONST(Variable*, var_);
@@ -182,7 +201,7 @@ DDim CompatMetaTensor::dims() const {
     if (var->IsType<phi::DenseTensor>()) {
       return var->Get<phi::DenseTensor>().dims();
     } else if (var->IsType<phi::SelectedRows>()) {
-      return var->Get<phi::SelectedRows>().dims();
+      return var->Get<phi::SelectedRows>().GetCompleteDims();
    } else if (var->IsType<phi::SparseCooTensor>()) {
       return var->Get<phi::SparseCooTensor>().dims();
     } else if (var->IsType<framework::LoDTensorArray>()) {
@@ -260,8 +279,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
     auto* tensor = var->GetMutable<phi::DenseTensor>();
     phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
   } else if (var->IsType<phi::SelectedRows>()) {
-    auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
-    phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
+    var->GetMutable<phi::SelectedRows>()->set_height(dims[0]);
   } else if (var->IsType<phi::SparseCooTensor>()) {
     auto* tensor = var->GetMutable<phi::SparseCooTensor>();
     phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
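
For context on the dims() and set_dims() changes above: a SelectedRows variable materializes only the selected rows in its value tensor, while its logical first dimension is the height. A minimal sketch follows (illustrative only, not part of the commit; the SelectedRows constructor and Resize call are assumptions based on the public phi::SelectedRows API):

// A SelectedRows holding 2 of 100 logical rows: the value tensor reports
// [2, 8], but shape inference needs the complete shape [100, 8], which is
// what GetCompleteDims() returns and what set_height() adjusts.
phi::SelectedRows sr(/*rows=*/{3, 7}, /*height=*/100);
phi::DenseTensor* value = sr.mutable_value();
value->Resize(phi::make_ddim({2, 8}));  // only the materialized rows
// value->dims()        -> [2, 8]
// sr.GetCompleteDims() -> [100, 8]
sr.set_height(100);  // the field CompatMetaTensor::set_dims now updates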

paddle/fluid/framework/infershape_utils.h

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ class CompatMetaTensor : public phi::MetaTensor {
 
   bool initialized() const override { return initialized_; };
 
+  bool is_selected_rows() const;
+
   bool is_tensor_array() const;
   bool is_dense() const;
 

paddle/fluid/framework/new_executor/standalone_executor_test.cc

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,7 @@ USE_OP_ITSELF(concat_grad);
 USE_OP_ITSELF(elementwise_mul_grad);
 USE_OP_ITSELF(sigmoid_grad);
 USE_OP_ITSELF(tanh_grad);
-USE_OP(sum);
+USE_OP_ITSELF(sum);
 USE_OP_ITSELF(slice_grad);
 USE_OP_ITSELF(lookup_table_grad);
 USE_OP_ITSELF(sqrt);
@@ -101,6 +101,7 @@ PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(add_n, GPU, ALL_LAYOUT);
 
 namespace paddle {
 namespace framework {

paddle/fluid/framework/operator.h

Lines changed: 7 additions & 0 deletions
@@ -512,6 +512,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
     });
   }
 
+  bool IsSelectedRowsInputs(const std::string& name) const override {
+    auto vars = ctx_.MultiInputVar(name);
+    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
+      return var->IsType<phi::SelectedRows>();
+    });
+  }
+
   bool IsSelectedRowsInput(const std::string& name) const override {
     const auto* var = ctx_.InputVar(name);
     return var->IsType<phi::SelectedRows>();
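
The plural IsSelectedRowsInputs added here (and in the other contexts above) lets an argument-mapping function ask whether every element of a vector input such as "X" is a SelectedRows, which the single-input IsSelectedRowsInput cannot express. A hedged sketch of how a mapping might use it (the kernel names are hypothetical and not copied from this commit):

phi::KernelSignature SumOpArgumentMapping(const phi::ArgumentMappingContext& ctx) {
  // Take the SelectedRows path only when all inputs in "X" are SelectedRows.
  if (ctx.IsSelectedRowsInputs("X")) {
    return phi::KernelSignature("add_n_sr", {"X"}, {}, {"Out"});  // hypothetical name
  }
  return phi::KernelSignature("add_n", {"X"}, {}, {"Out"});
}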

paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc

Lines changed: 6 additions & 0 deletions
@@ -104,6 +104,7 @@ bool PluginArgumentMappingContext::IsSelectedRowsInput(
     const std::string& name) const {
   return false;
 }
+
 bool PluginArgumentMappingContext::IsSparseCooTensorInput(
     const std::string& name) const {
   return false;
@@ -112,6 +113,11 @@ bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
     const std::string& name) const {
   return false;
 }
+
+bool PluginArgumentMappingContext::IsSelectedRowsInputs(
+    const std::string& name) const {
+  return false;
+}
 bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
     const std::string& name) const {
   return false;

paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
 
   bool IsSparseCsrTensorInput(const std::string& name) const override;
 
+  bool IsSelectedRowsInputs(const std::string& name) const override;
+
   bool IsDenseTensorVectorInput(const std::string& name) const override;
 
   bool IsDenseTensorOutput(const std::string& name) const override;

paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc

Lines changed: 5 additions & 1 deletion
@@ -24,7 +24,8 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/sum_op.h"
+#include "paddle/fluid/framework/lod_tensor_array.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
 
 namespace phi {
@@ -37,6 +38,9 @@ namespace operators {
 using paddle::platform::MKLDNNDeviceContext;
 using phi::CPUContext;
 using platform::to_void_cast;
+using Tensor = framework::Tensor;
+using SelectedRows = phi::SelectedRows;
+using LoDTensor = framework::LoDTensor;
 
 template <typename T>
 class SumMKLDNNHandler

paddle/fluid/operators/sum_op.cc

Lines changed: 11 additions & 99 deletions
@@ -9,15 +9,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/sum_op.h"
-
 #include <algorithm>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -32,94 +34,6 @@ class SumOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "sum");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sum");
-
-    if (ctx->IsRuntime() && ctx->GetOutputsVarType("Out")[0] ==
-                                framework::proto::VarType::LOD_TENSOR_ARRAY) {
-      return;  // skip runtime infershape when is tensor array;
-    }
-
-    auto x_var_types = ctx->GetInputsVarType("X");
-    auto x_dims = ctx->GetInputsDim("X");
-
-    auto N = x_dims.size();
-    PADDLE_ENFORCE_GT(
-        N,
-        0,
-        platform::errors::InvalidArgument(
-            "The input tensor X's dimensions of SumOp "
-            "should be larger than 0. But received X's dimensions %d, "
-            "X's shape = [%s].",
-            N,
-            &x_dims));
-    if (N == 1) {
-      VLOG(3) << "Warning: SumOp have only one input, may waste memory";
-    }
-
-    framework::DDim in_dim({0});
-    for (size_t i = 0; i < x_dims.size(); ++i) {
-      auto& x_dim = x_dims[i];
-      // x_dim.size() == 1 means the real dim of selected rows is [0]
-      if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
-          x_dim.size() == 1) {
-        continue;
-      }
-      if (phi::product(x_dim) == 0) {
-        continue;
-      }
-      if (phi::product(in_dim) == 0) {
-        in_dim = x_dim;
-      } else {
-        if (ctx->IsRuntime()) {
-          PADDLE_ENFORCE_EQ(in_dim,
-                            x_dim,
-                            platform::errors::InvalidArgument(
-                                "The input tensor X of SumOp must"
-                                " have same shape. But received X[0]'s shape = "
-                                "[%s], X[%d]'s shape = [%s].",
-                                in_dim,
-                                i,
-                                x_dim));
-        } else {
-          PADDLE_ENFORCE_EQ(
-              in_dim.size(),
-              x_dim.size(),
-              platform::errors::InvalidArgument(
-                  "The input tensor X of SumOp must have same "
-                  "dimensions. But received X[0]'s dimensions = %d, X[0]'s "
-                  "shape = "
-                  "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
-                  in_dim.size(),
-                  in_dim,
-                  i,
-                  x_dim.size(),
-                  i,
-                  x_dim));
-          // if in_dim or x_dim has -1, not check equal
-          for (int j = 0; j < x_dim.size(); ++j) {
-            if (x_dim[j] == -1 || in_dim[j] == -1) {
-              continue;
-            }
-            PADDLE_ENFORCE_EQ(
-                in_dim[j],
-                x_dim[j],
-                platform::errors::InvalidArgument(
-                    "The input tensor X of SumOp must have same shape "
-                    "if not -1."
-                    "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
-                    in_dim,
-                    i,
-                    x_dim));
-          }
-        }
-      }
-    }
-    ctx->SetOutputDim("Out", in_dim);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -350,18 +264,16 @@ DECLARE_INPLACE_OP_INFERER(SumInplaceInferer, {"X", "Out"});
 
 namespace ops = paddle::operators;
 
+namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(sum,
+                            AddNInferShapeFunctor,
+                            PD_INFER_META(phi::AddNTensorArrayInferMeta));
+
 REGISTER_OPERATOR(sum,
                   ops::SumOp,
                   ops::SumOpMaker,
                   ops::SumGradDescMaker,
                   ops::SumGradOpBaseMaker,
                   ops::SumOpVarTypeInference,
-                  ops::SumInplaceInferer);
-
-REGISTER_OP_CPU_KERNEL(
-    sum,
-    ops::SumKernel<phi::CPUContext, float>,
-    ops::SumKernel<phi::CPUContext, double>,
-    ops::SumKernel<phi::CPUContext, int>,
-    ops::SumKernel<phi::CPUContext, paddle::platform::bfloat16>,
-    ops::SumKernel<phi::CPUContext, int64_t>);
+                  ops::SumInplaceInferer,
+                  AddNInferShapeFunctor);
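
With the hand-written SumOp::InferShape removed, shape inference for sum goes through the PHI infermeta function wired up by DECLARE_INFER_SHAPE_FUNCTOR above. A minimal, hedged sketch of what such a function looks like (hypothetical MyAddNInferMeta name and deliberately simplified checks, not the real phi::AddNTensorArrayInferMeta):

// Simplified illustration of the infermeta style: shape, dtype, and LoD are
// set on phi::MetaTensor, so the same logic serves static graph and dygraph.
void MyAddNInferMeta(const std::vector<const phi::MetaTensor*>& x,
                     phi::MetaTensor* out) {
  PADDLE_ENFORCE_GT(
      x.size(),
      0UL,
      phi::errors::InvalidArgument("sum/add_n expects at least one input."));
  out->set_dims(x[0]->dims());    // the real implementation merges all input dims
  out->set_dtype(x[0]->dtype());
  out->share_lod(*x[0]);
}

The functor generated by DECLARE_INFER_SHAPE_FUNCTOR adapts this MetaTensor-based signature to the legacy InferShapeContext, which is why the registration above only needs the extra AddNInferShapeFunctor argument.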
