Skip to content

Commit 33d7aae

Browse files
liym27 and Aurelius84
authored and committed
Cherry pick bug fix for Ops: reshape, concat, split and squeeze (#20929)
* [cherry-pick] fix bug in reshape: (#20781) consider the situation that the shape of input can contain more than one -1. * [cherry-pick] support Tensor for split and concat, support -1 in num_or_sections, add check num_or_sections (#20780) * improve split and concat op: 1. support Tensor for argument 'dim' in split op. 2. support Tensor for argument 'axis' in concat op. * redefine function GetDataFromTensor and set unknown output shape to -1. * add check: Attr(sections) match Input(X). * support Tensor for attr(sections) and attr(sections) can contain -1. * modify error message and fix bug for concat and call Resize only when necessary. test=release/1.6 * [cherry-pick] improve unsqueeze op to support int, Tensor for argument axes (#20824) * improve unsqueeze op to support int, Tensor and Tensor list for argument axes. * call Resize only when necessary. test=release/1.6 * [cherry-pick] Compatible int32 and int64 for attr in concat/split/unsqueeze. test=release/1.6 (#20912)
1 parent de130e9 commit 33d7aae

File tree

14 files changed

+1033
-184
lines changed

14 files changed

+1033
-184
lines changed

paddle/fluid/operators/concat_op.cc

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,58 +32,36 @@ class ConcatOp : public framework::OperatorWithKernel {
3232
void InferShape(framework::InferShapeContext *ctx) const override {
3333
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
3434
"Inputs(X) of ConcatOp should not be empty.");
35-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
36-
"Output(Out) of ConcatOp should not be null.");
35+
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
36+
"Output(Out) of ConcatOp should not be null.");
3737

38-
auto ins = ctx->GetInputsDim("X");
39-
size_t axis =
40-
ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
41-
static_cast<int64_t>(ins[0].size()));
38+
auto inputs_dims = ctx->GetInputsDim("X");
4239

43-
const size_t n = ins.size();
44-
PADDLE_ENFORCE_GT(n, 0,
40+
const size_t inputs_num = inputs_dims.size();
41+
PADDLE_ENFORCE_GT(inputs_num, 0,
4542
"ShapeError: Input tensors count should > 0. But "
4643
"recevied inputs' length is 0.");
47-
if (n == 1) {
44+
if (inputs_num == 1) {
4845
VLOG(3) << "Warning: concat op have only one input, may waste memory";
4946
}
5047

51-
auto out_dims = ins[0];
52-
size_t in_zero_dims_size = out_dims.size();
53-
for (size_t i = 1; i < n; i++) {
54-
for (size_t j = 0; j < in_zero_dims_size; j++) {
55-
if (j == axis) {
56-
if (ctx->IsRuntime()) {
57-
out_dims[axis] += ins[i][j];
58-
} else {
59-
if (ins[i][j] == -1) {
60-
out_dims[axis] = -1;
61-
} else {
62-
out_dims[axis] += ins[i][j];
63-
}
64-
}
65-
} else {
66-
bool check_shape =
67-
ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0);
68-
if (check_shape) {
69-
// check all shape in run time
70-
PADDLE_ENFORCE_EQ(
71-
out_dims[j], ins[i][j],
72-
"ShapeError: Input tensors should have same "
73-
"dimensions(or specific dimension = -1) except the axis. "
74-
"But recevied axis = %s, input[0]'s shape = "
75-
"[%s], input[%s]'s shape = [%s], the \"%s\" "
76-
"dimension of input[%s] is unexpected",
77-
axis, ins[0], i, ins[j], j, i);
78-
}
79-
}
48+
if (ctx->HasInput("AxisTensor")) {
49+
auto out_dims =
50+
framework::make_ddim(std::vector<int>(inputs_dims[0].size(), -1));
51+
ctx->SetOutputDim("Out", out_dims);
52+
ctx->ShareLoD("X", /*->*/ "Out");
53+
} else {
54+
size_t axis =
55+
ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
56+
static_cast<int64_t>(inputs_dims[0].size()));
57+
framework::DDim out_dims =
58+
ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims, axis);
59+
if (out_dims[axis] < 0) {
60+
out_dims[axis] = -1;
8061
}
62+
ctx->SetOutputDim("Out", out_dims);
63+
ctx->ShareLoD("X", /*->*/ "Out");
8164
}
82-
if (out_dims[axis] < 0) {
83-
out_dims[axis] = -1;
84-
}
85-
ctx->SetOutputDim("Out", out_dims);
86-
ctx->ShareLoD("X", /*->*/ "Out");
8765
}
8866

8967
protected:
@@ -111,6 +89,16 @@ class ConcatOp : public framework::OperatorWithKernel {
11189
#endif
11290
return framework::OpKernelType(input_data_type, ctx.GetPlace());
11391
}
92+
93+
framework::OpKernelType GetKernelTypeForVar(
94+
const std::string &var_name, const Tensor &tensor,
95+
const framework::OpKernelType &expected_kernel_type) const override {
96+
if (var_name == "AxisTensor") {
97+
return expected_kernel_type;
98+
}
99+
return framework::OpKernelType(expected_kernel_type.data_type_,
100+
tensor.place(), tensor.layout());
101+
}
114102
};
115103

116104
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -128,6 +116,12 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
128116
"interpreted as counting from the end of the rank."
129117
"i.e., axis + rank(X) th dimension.")
130118
.SetDefault(0);
119+
AddInput("AxisTensor",
120+
"(Tensor) The axis along which the input tensors will be "
121+
"concatenated. "
122+
"It has higher priority than Attr(axis). "
123+
"The shape of AxisTensor must be [1].")
124+
.AsDispensable();
131125
AddAttr<bool>("use_quantizer",
132126
"(bool, default false) "
133127
"Set to true for operators that should be quantized and use "
@@ -178,6 +172,16 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
178172
ctx, framework::GradVarName("Out")),
179173
ctx.GetPlace());
180174
}
175+
176+
framework::OpKernelType GetKernelTypeForVar(
177+
const std::string &var_name, const Tensor &tensor,
178+
const framework::OpKernelType &expected_kernel_type) const override {
179+
if (var_name == "AxisTensor") {
180+
return expected_kernel_type;
181+
}
182+
return framework::OpKernelType(expected_kernel_type.data_type_,
183+
tensor.place(), tensor.layout());
184+
}
181185
};
182186

183187
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
@@ -192,6 +196,7 @@ class ConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
192196
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
193197
op->SetType("concat_grad");
194198
op->SetInput("X", Input("X"));
199+
op->SetInput("AxisTensor", Input("AxisTensor"));
195200
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
196201
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
197202
op->SetAttrMap(Attrs());

paddle/fluid/operators/concat_op.h

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,51 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <string>
1718
#include <utility>
1819
#include <vector>
1920
#include "paddle/fluid/framework/op_registry.h"
2021
#include "paddle/fluid/operators/math/concat_and_split.h"
2122
#include "paddle/fluid/operators/strided_memcpy.h"
23+
#include "paddle/fluid/operators/utils.h"
2224

2325
namespace paddle {
2426
namespace operators {
27+
static inline framework::DDim ComputeAndCheckShape(
28+
const bool is_runtime, const std::vector<framework::DDim>& inputs_dims,
29+
const int axis) {
30+
const size_t n = inputs_dims.size();
31+
auto out_dims = inputs_dims[0];
32+
size_t in_zero_dims_size = out_dims.size();
33+
for (size_t i = 1; i < n; i++) {
34+
for (size_t j = 0; j < in_zero_dims_size; j++) {
35+
if (j == axis) {
36+
if (is_runtime) {
37+
out_dims[axis] += inputs_dims[i][j];
38+
} else {
39+
if (inputs_dims[i][j] == -1) {
40+
out_dims[axis] = -1;
41+
} else {
42+
out_dims[axis] += inputs_dims[i][j];
43+
}
44+
}
45+
} else {
46+
bool check_shape =
47+
is_runtime || (out_dims[j] > 0 && inputs_dims[i][j] > 0);
48+
if (check_shape) {
49+
// check all shape in run time
50+
PADDLE_ENFORCE_EQ(
51+
inputs_dims[0][j], inputs_dims[i][j],
52+
"ShapeError: Dimension %d in inputs' shapes must be equal. "
53+
"But recevied input[0]'s shape = "
54+
"[%s], input[%d]'s shape = [%s].",
55+
j, inputs_dims[0], i, inputs_dims[i]);
56+
}
57+
}
58+
}
59+
}
60+
return out_dims;
61+
}
2562

2663
static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
2764
if (axis < 0) {
@@ -36,9 +73,27 @@ class ConcatKernel : public framework::OpKernel<T> {
3673
void Compute(const framework::ExecutionContext& ctx) const override {
3774
auto ins = ctx.MultiInput<framework::Tensor>("X");
3875
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
39-
PADDLE_ENFORCE(ins[0], "The input should not be null.");
40-
auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
41-
static_cast<int64_t>(ins[0]->dims().size()));
76+
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
77+
auto axis = ctx.Attr<int>("axis");
78+
bool need_resize_out_dims = false;
79+
if (ctx.HasInput("AxisTensor")) {
80+
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
81+
axis = GetDataFromTensor<int>(axis_tensor)[0];
82+
need_resize_out_dims = true;
83+
}
84+
axis = ComputeAxis(static_cast<int64_t>(axis),
85+
static_cast<int64_t>(ins[0]->dims().size()));
86+
87+
if (need_resize_out_dims) {
88+
const size_t n = ins.size();
89+
std::vector<framework::DDim> ins_dims(n);
90+
for (size_t i = 0; i < n; i++) {
91+
ins_dims[i] = ins[i]->dims();
92+
}
93+
94+
framework::DDim out_dims = ComputeAndCheckShape(true, ins_dims, axis);
95+
out->Resize(out_dims);
96+
}
4297
auto place = ctx.GetPlace();
4398
out->mutable_data<T>(place);
4499

@@ -92,10 +147,15 @@ class ConcatGradKernel : public framework::OpKernel<T> {
92147
}
93148
}
94149
}
95-
PADDLE_ENFORCE(ins[0], "The input should not be null.");
96-
auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
97-
static_cast<int64_t>(ins[0]->dims().size()));
150+
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
98151

152+
auto axis = ctx.Attr<int>("axis");
153+
if (ctx.HasInput("AxisTensor")) {
154+
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
155+
axis = GetDataFromTensor<int>(axis_tensor)[0];
156+
}
157+
axis = ComputeAxis(static_cast<int64_t>(axis),
158+
static_cast<int64_t>(ins[0]->dims().size()));
99159
// get output tensor that the name is not kEmptyVarName
100160
std::vector<framework::Tensor*> outputs;
101161
for (size_t j = 0; j < outs.size(); ++j) {

paddle/fluid/operators/reshape_op.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
186186
output_shape[unk_dim_idx] = -1;
187187
}
188188
} else {
189-
PADDLE_ENFORCE_EQ(
190-
capacity, in_size,
191-
"ShapeError: The 'shape' in ReshapeOp is invalid. "
192-
"The input tensor X'size must be equal to the capacity of 'shape'. "
193-
"But received X's shape = [%s], X's size = %d, 'shape' is [%s], the "
194-
"capacity of 'shape' is %d.",
195-
in_dims, in_size, framework::make_ddim(shape), capacity);
189+
if (all_positive) {
190+
PADDLE_ENFORCE_EQ(
191+
capacity, in_size,
192+
"ShapeError: The 'shape' in ReshapeOp is invalid. "
193+
"The input tensor X'size must be equal to the capacity of 'shape'. "
194+
"But received X's shape = [%s], X's size = %d, 'shape' is [%s], "
195+
"the "
196+
"capacity of 'shape' is %d.",
197+
in_dims, in_size, framework::make_ddim(shape), capacity);
198+
}
196199
}
197200
return framework::make_ddim(output_shape);
198201
}

paddle/fluid/operators/split_op.cc

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/split_op.h"
16+
#include <string>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -23,8 +24,8 @@ class SplitOp : public framework::OperatorWithKernel {
2324
using framework::OperatorWithKernel::OperatorWithKernel;
2425

2526
void InferShape(framework::InferShapeContext *ctx) const override {
26-
PADDLE_ENFORCE(ctx->HasInput("X"),
27-
"Input(X) of SplitOp should not be null.");
27+
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
28+
"Input(X) of SplitOp should not be null.");
2829
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
2930
"Outputs(Out) of SplitOp should not be empty.");
3031
auto in_dims = ctx->GetInputDim("X");
@@ -34,38 +35,29 @@ class SplitOp : public framework::OperatorWithKernel {
3435
std::vector<int> sections = static_cast<std::vector<int>>(
3536
ctx->Attrs().Get<std::vector<int>>("sections"));
3637
const size_t outs_number = outs_names.size();
37-
std::vector<framework::DDim> outs_dims;
38-
outs_dims.reserve(outs_number);
39-
40-
if (num > 0) {
41-
int64_t in_axis_dim = in_dims[axis];
42-
if (ctx->IsRuntime() || in_axis_dim > 0) {
43-
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
44-
"tensor split does not result"
45-
" in an equal division");
46-
size_t out_axis_dim = in_axis_dim / num;
47-
for (size_t i = 0; i < outs_number; ++i) {
48-
auto dim = in_dims;
49-
dim[axis] = out_axis_dim;
50-
outs_dims.push_back(dim);
51-
}
52-
} else {
53-
for (size_t i = 0; i < outs_number; ++i) {
54-
auto dim = in_dims;
55-
dim[axis] = -1;
56-
outs_dims.push_back(dim);
57-
}
58-
}
59-
} else if (sections.size() > 0) {
38+
39+
if (sections.size() > 0) {
6040
PADDLE_ENFORCE_EQ(sections.size(), outs_number,
61-
"tensor split sections size"
41+
"tensor split sections size "
6242
"should be equal to output size.");
43+
}
44+
45+
if (ctx->HasInput("AxisTensor")) {
46+
auto out_dims =
47+
framework::make_ddim(std::vector<int>(in_dims.size(), -1));
48+
std::vector<framework::DDim> outs_dims(outs_number, out_dims);
49+
ctx->SetOutputsDim("Out", outs_dims);
6350
for (size_t i = 0; i < outs_number; ++i) {
64-
auto dim = in_dims;
65-
dim[axis] = sections[i];
66-
outs_dims.push_back(dim);
51+
ctx->ShareLoD("X", "Out", 0, i);
6752
}
53+
return;
6854
}
55+
56+
bool each_section_is_known =
57+
(sections.size() > 0 && !ctx->HasInputs("SectionsTensorList"));
58+
59+
auto outs_dims = UpdateOutsDims(ctx->IsRuntime(), each_section_is_known,
60+
in_dims, num, sections, axis, outs_number);
6961
ctx->SetOutputsDim("Out", outs_dims);
7062
if (axis != 0) {
7163
// Only pass LoD when not spliting along the first dim.
@@ -74,12 +66,41 @@ class SplitOp : public framework::OperatorWithKernel {
7466
}
7567
}
7668
}
69+
70+
protected:
71+
framework::OpKernelType GetExpectedKernelType(
72+
const framework::ExecutionContext &ctx) const override {
73+
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
74+
ctx.device_context());
75+
}
76+
77+
framework::OpKernelType GetKernelTypeForVar(
78+
const std::string &var_name, const Tensor &tensor,
79+
const framework::OpKernelType &expected_kernel_type) const override {
80+
if (var_name == "AxisTensor" || var_name == "SectionsTensorList") {
81+
return expected_kernel_type;
82+
}
83+
return framework::OpKernelType(expected_kernel_type.data_type_,
84+
tensor.place(), tensor.layout());
85+
}
7786
};
7887

7988
class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
8089
public:
8190
void Make() override {
8291
AddInput("X", "(Tensor) Input tensor of the split operator.");
92+
AddInput("AxisTensor",
93+
"(Tensor) The axis which the input will be splited on. "
94+
"It has higher priority than Attr(axis). "
95+
"The shape of AxisTensor must be [1]")
96+
.AsDispensable();
97+
AddInput("SectionsTensorList",
98+
"(vector<Tensor<int>>, optional). "
99+
"The length of each output along the specified axis. "
100+
"It has a higher priority than Attr(sections)."
101+
"The shape of the element in vector must be [1].")
102+
.AsDuplicable()
103+
.AsDispensable();
83104
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
84105
.AsDuplicable();
85106
AddComment(R"DOC(

0 commit comments

Comments
 (0)