Commit d908c3b

Merge pull request #9008 from lcy-seso/enhance_reshape
Enhance reshape
2 parents 6cfc0c1 + 5b8bb34 commit d908c3b

9 files changed: +375 additions, -111 deletions

paddle/fluid/operators/reshape_op.cc
Lines changed: 57 additions & 73 deletions

@@ -17,90 +17,66 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-class ReshapeOp : public framework::OperatorWithKernel {
- public:
-  ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
-            const framework::VariableNameMap &outputs,
-            const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    // input check
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of ReshapeOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of ReshapeOp should not be null.");
-
-    auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
-    PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
-    auto x_dims = ctx->GetInputDim("X");
-
-    std::vector<size_t> neg_dims_idx;
-    // set some dimension to -1 if it is unknown
-    const int unknown_size = -1;
-    for (size_t i = 0; i < shape.size(); ++i) {
-      PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
-                     "Each dimension of Attr(shape) must be positive or %d.",
-                     unknown_size);
-      if (shape[i] == unknown_size) {
-        neg_dims_idx.push_back(i);
-        PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
-                       "Only one dimension of Attr(shape) can be unknown.");
-      }
-    }
-
-    int64_t capacity =
-        std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
-    int64_t in_size = framework::product(x_dims);
-    if (neg_dims_idx.size() == 1) {
-      // dim infer
-      shape[neg_dims_idx[0]] = in_size / (-capacity);
-      // recalculate capacity
-      capacity = shape[neg_dims_idx[0]] * (-capacity);
-    }
-    // capacity check
-    PADDLE_ENFORCE(capacity == in_size,
-                   "The size of Input(X) mismatches with Attr(shape).");
-    // resize output
-    std::vector<int64_t> shape_int64(shape.size(), 0);
-    std::transform(shape.begin(), shape.end(), shape_int64.begin(),
-                   [](int a) { return static_cast<int64_t>(a); });
-    auto out_dims = framework::make_ddim(shape_int64);
-    ctx->SetOutputDim("Out", out_dims);
-    if (shape[0] == x_dims[0]) {
-      // Only pass LoD when the first dimension is equal between
-      // output and input.
-      ctx->ShareLoD("X", /*->*/ "Out");
-    }
-  }
-};
-
 class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of reshape operator.");
-    AddOutput("Out", "The output tensor of reshape operator.");
-    AddAttr<std::vector<int>>("shape",
-                              "(vector<int>) "
-                              "Target shape of reshape operator.");
+    AddInput("X", "(Tensor). The input tensor of reshape operator.");
+    AddInput("Shape",
+             "(Tensor<int32>, optional). If provided, reshape according to "
+             "this given shape. That is to say it has a higher priority than "
+             "the shape attribute, while the shape attribute still should be "
+             "set correctly to guarantee shape inference in compile time.")
+        .AsDispensable();
+    AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
+    AddAttr<std::vector<int>>(
+        "shape", "(std::vector<int>) Target shape of reshape operator.");
     AddAttr<bool>("inplace",
-                  "Change the source tensor's shape without copy memory.")
-        .SetDefault(true);
+                  "(default: false) Change the source tensor's shape without "
+                  "memory copy. When Attr(inplace) is set true, the output "
+                  "tensor shares memory with Input(X); otherwise, a new output "
+                  "tensor is created, and its data are copied from Input(X).")
+        .SetDefault(false);
     AddComment(R"DOC(
 Reshape Operator.
 
-Reshape Input(X) into the shape specified by Attr(shape).
+Reshape Input(X) into the shape specified by Attr(shape) or Input(Shape). The
+data in Input(X) are unchanged.
+
+Examples:
 
-An example:
-Given a 2-D tensor X with 2 rows and 2 columns : [[1, 2], [3, 4]]
+1. Given a 3-D tensor Input(X) with a shape [2, 4, 6], and the target shape
+specified by Attr(shape) is [6, 8], the reshape operator will transform Input(X)
+into a 2-D tensor with shape [6, 8], leaving Input(X)'s data unchanged.
 
-and target shape = [1, 4], the reshape operator will transform
-the tensor X into a 2-D tensor: [[1, 2, 3, 4]]
+2. Given a 3-D tensor Input(X) with a shape [2, 4, 6], and the target shape
+specified by Attr(shape) is [2, 3, -1, 2], the reshape operator will transform
+Input(X) into a 4-D tensor with shape [2, 3, 4, 2], leaving Input(X)'s data
+unchanged. In this case, one and only one dimension of Attr(shape) can be set
+to -1; the value of this dimension is inferred from the total element number of
+Input(X) and the remaining dimensions.
+
+3. Given a 3-D tensor Input(X) with a shape [2, 4, 6], and the target shape
+specified by Attr(shape) is [-1, 0, 3, 2], the reshape operator will transform
+Input(X) into a 4-D tensor with shape [2, 4, 3, 2], leaving Input(X)'s data
+unchanged. In this case, besides -1, 0 means the actual dimension value is going
+to be copied from the corresponding dimension of Input(X).
+
+Note:
+
+1. One and only one dimension in Attr(shape) can be set to -1. In this case,
+the actual dimension value will be inferred from the total element number of
+Input(X) and the remaining dimensions.
+
+2. More than one dimension in Attr(shape) can be set to 0, which means the real
+dimension value will be copied from Input(X) at runtime. Note that the index of
+a 0 cannot exceed Rank(X). For example, if Input(X) is a 3-D tensor with shape
+[2, 3, 4], then Attr(shape) = [2, 3, 2, 0] is an invalid input.
+
+3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
+Attr(shape) still should be set correctly to guarantee shape inference in
+compile-time.
 
-One dimension in the target shape can be set -1, representing that its
-size is unknown. In this case, the real dimension will be infered from
-the original shape of Input(X) and other dimensions in the target shape.
 )DOC");
   }
 };
@@ -119,6 +95,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
                    "Input(Out@GRAD) shouldn't be null.");
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
 };
 
 }  // namespace operators
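
The three examples in the new DOC block can be sanity-checked with plain NumPy. The sketch below is ours, not part of the patch; note that NumPy has no 0-as-copy convention, so the zeros are resolved against the input shape by hand first:

import numpy as np

x = np.arange(2 * 4 * 6).reshape(2, 4, 6)   # Input(X) with shape [2, 4, 6]

# Example 1: plain reshape, Attr(shape) = [6, 8].
print(x.reshape(6, 8).shape)                # (6, 8)

# Example 2: the single -1 is inferred, [2, 3, -1, 2] -> [2, 3, 4, 2].
print(x.reshape(2, 3, -1, 2).shape)         # (2, 3, 4, 2)

# Example 3: resolve 0-as-copy against x.shape, then let NumPy infer the -1:
# [-1, 0, 3, 2] -> [-1, 4, 3, 2] -> [2, 4, 3, 2].
target = [d if d != 0 else x.shape[i] for i, d in enumerate([-1, 0, 3, 2])]
print(x.reshape(target).shape)              # (2, 4, 3, 2)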

paddle/fluid/operators/reshape_op.h
Lines changed: 120 additions & 7 deletions

@@ -20,17 +20,129 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+class ReshapeOp : public framework::OperatorWithKernel {
+ public:
+  ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ReshapeOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ReshapeOp should not be null.");
+
+    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
+    PADDLE_ENFORCE(!shape.empty(),
+                   "The shape information must be set by Attr(shape).");
+
+    if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
+      // If true, set the shape of Output(Out) according to Input(Shape) in
+      // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
+      ctx->ShareLoD("X", /*->*/ "Out");
+      return;
+    }
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto out_dims = ValidateShape(shape, x_dims);
+    ctx->SetOutputDim("Out", out_dims);
+    if (x_dims[0] == out_dims[0]) {
+      // Only pass LoD when the first dimension of output and Input(X)
+      // are the same.
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
+  }
+
+  static framework::DDim ValidateShape(const std::vector<int> shape,
+                                       const framework::DDim &in_dims) {
+    const int64_t in_size = framework::product(in_dims);
+    // Only one dimension can be set to -1, whose size will be automatically
+    // inferred.
+    const int64_t unk_dim_val = -1;
+    const int64_t copy_dim_val = 0;
+
+    std::vector<int64_t> output_shape(shape.size(), 0);
+    int64_t capacity = 1;
+    int unk_dim_idx = -1;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      if (shape[i] == unk_dim_val) {
+        PADDLE_ENFORCE(
+            unk_dim_idx == -1,
+            "Only one input dimension of Attr(shape) can be unknown.");
+        unk_dim_idx = i;
+      } else if (shape[i] == copy_dim_val) {
+        PADDLE_ENFORCE(
+            static_cast<int>(i) < in_dims.size(),
+            "The index of dimension to copy from input shape must be less "
+            "than the size of input shape.");
+      } else {
+        PADDLE_ENFORCE(
+            shape[i] > 0,
+            "Each input dimension of Attr(shape) must not be negative except "
+            "one unknown dimension.");
+      }
+
+      capacity *= (shape[i] ? shape[i] : in_dims[i]);
+      output_shape[i] =
+          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
+    }
+
+    if (unk_dim_idx != -1) {
+      output_shape[unk_dim_idx] = -in_size / capacity;
+      PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
+                        "Invalid shape is given.");
+    } else {
+      PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
+    }
+    return framework::make_ddim(output_shape);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.device_context());
+  }
+};
+
 template <typename DeviceContext, typename T>
 class ReshapeKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* out = ctx.Output<framework::Tensor>("Out");
-    auto* in = ctx.Input<framework::Tensor>("X");
+  void Compute(const framework::ExecutionContext &ctx) const {
+    auto *out = ctx.Output<framework::LoDTensor>("Out");
+    auto *in = ctx.Input<framework::LoDTensor>("X");
+    auto *shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
+
+    framework::DDim out_dims = out->dims();
+    if (shape_tensor) {
+      auto *shape_data = shape_tensor->data<int>();
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        framework::Tensor cpu_shape_tensor;
+        TensorCopy(*shape_tensor, platform::CPUPlace(), ctx.device_context(),
+                   &cpu_shape_tensor);
+        shape_data = cpu_shape_tensor.data<int>();
+      }
+      auto shape =
+          std::vector<int>(shape_data, shape_data + shape_tensor->numel());
+      out_dims = ReshapeOp::ValidateShape(shape, in->dims());
+    }
+    if (!in->lod().empty()) {
+      PADDLE_ENFORCE_EQ(
+          out_dims[0], in->dims()[0],
+          "Reshape operator cannot reshape an input sequence batch "
+          "into an output sequence batch that has a different "
+          "number of time steps. Please consider using "
+          "sequence_reshape op.");
+    }
+
     bool inplace = ctx.Attr<bool>("inplace");
-    auto out_dims = out->dims();
+    out->Resize(out_dims);
     if (!inplace) {
       out->mutable_data<T>(ctx.GetPlace());
       framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
+      // TensorCopy will resize to in_dims.
       out->Resize(out_dims);
     } else {
       out->ShareDataWith(*in);
@@ -42,9 +154,10 @@ class ReshapeKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class ReshapeGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+  void Compute(const framework::ExecutionContext &ctx) const {
+    auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
     d_x->mutable_data<T>(ctx.GetPlace());
     bool inplace = ctx.Attr<bool>("inplace");

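For reference, the following Python is an unofficial model of ReshapeOp::ValidateShape above, written to make the -1/0 rules and the capacity check easy to trace; it is a sketch of the algorithm, not code from the patch:

import numpy as np


def validate_shape(shape, in_dims):
    """Python model of ReshapeOp::ValidateShape (a sketch, not the C++ code)."""
    in_size = int(np.prod(in_dims))
    output_shape, capacity, unk_dim_idx = [], 1, -1
    for i, d in enumerate(shape):
        if d == -1:                       # the single dimension to infer
            assert unk_dim_idx == -1, "Only one dimension can be unknown."
            unk_dim_idx = i
        elif d == 0:                      # copy this dimension from in_dims
            assert i < len(in_dims), "0 must index an existing input dim."
        else:
            assert d > 0, "Dimensions must be positive, -1, or 0."
        output_shape.append(d if d != 0 else in_dims[i])
        capacity *= output_shape[-1]
    if unk_dim_idx != -1:
        # As in the C++ code, capacity carries the -1 placeholder as a factor,
        # so negating in_size makes the division come out positive.
        output_shape[unk_dim_idx] = -in_size // capacity
        assert output_shape[unk_dim_idx] * capacity == -in_size, "Invalid shape."
    else:
        assert capacity == in_size, "Invalid shape."
    return output_shape


print(validate_shape([-1, 0, 3, 2], [2, 4, 6]))  # [2, 4, 3, 2]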
python/paddle/fluid/layers/detection.py
Lines changed: 10 additions & 11 deletions

@@ -19,7 +19,6 @@
 from layer_function_generator import autodoc
 from ..layer_helper import LayerHelper
 import tensor
-import ops
 import nn
 import math
 
@@ -58,7 +57,7 @@ def detection_output(loc,
 
     This operation is to get the detection results by performing following
     two steps:
-
+
     1. Decode input bounding box predictions according to the prior boxes.
     2. Get the final detection results by applying multi-class non maximum
        suppression (NMS).
@@ -130,9 +129,9 @@ class number, M is number of bounding boxes. For each category
         target_box=loc,
         code_type='decode_center_size')
     old_shape = scores.shape
-    scores = ops.reshape(x=scores, shape=(-1, old_shape[-1]))
+    scores = nn.reshape(x=scores, shape=(-1, old_shape[-1]))
     scores = nn.softmax(input=scores)
-    scores = ops.reshape(x=scores, shape=old_shape)
+    scores = nn.reshape(x=scores, shape=old_shape)
     scores = nn.transpose(scores, perm=[0, 2, 1])
     scores.stop_gradient = True
     nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
@@ -463,7 +462,7 @@ def ssd_loss(location,
     num, num_prior, num_class = confidence.shape
 
     def __reshape_to_2d(var):
-        return ops.reshape(x=var, shape=[-1, var.shape[-1]])
+        return nn.reshape(x=var, shape=[-1, var.shape[-1]])
 
     # 1. Find matched boundding box by prior box.
     # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
@@ -474,7 +473,7 @@ def __reshape_to_2d(var):
 
     # 2. Compute confidence for mining hard examples
    # 2.1. Get the target label based on matched indices
-    gt_label = ops.reshape(x=gt_label, shape=gt_label.shape + (1, ))
+    gt_label = nn.reshape(x=gt_label, shape=gt_label.shape + (1, ))
     gt_label.stop_gradient = True
     target_label, _ = target_assign(
         gt_label, matched_indices, mismatch_value=background_label)
@@ -487,7 +486,7 @@ def __reshape_to_2d(var):
     conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
 
     # 3. Mining hard examples
-    conf_loss = ops.reshape(x=conf_loss, shape=(num, num_prior))
+    conf_loss = nn.reshape(x=conf_loss, shape=(num, num_prior))
     conf_loss.stop_gradient = True
     neg_indices = helper.create_tmp_variable(dtype='int32')
     dtype = matched_indices.dtype
@@ -556,7 +555,7 @@ def __reshape_to_2d(var):
     # 5.3 Compute overall weighted loss.
     loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
     # reshape to [N, Np], N is the batch size and Np is the prior box number.
-    loss = ops.reshape(x=loss, shape=[-1, num_prior])
+    loss = nn.reshape(x=loss, shape=[-1, num_prior])
     loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
     if normalize:
         normalizer = nn.reduce_sum(target_loc_weight)
@@ -709,7 +708,7 @@ def _reshape_with_axis_(input, axis=1):
         new_shape = [
            -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
        ]
-        out = ops.reshape(x=input, shape=new_shape)
+        out = nn.reshape(x=input, shape=new_shape)
        return out
 
     def _is_list_or_tuple_(data):
@@ -803,7 +802,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
            mbox_loc.shape[0],
            mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
        ]
-        mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape)
+        mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
        mbox_locs.append(mbox_loc_flatten)
 
        # get conf
@@ -819,7 +818,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
            conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
            conf_loc.shape[3] / num_classes, num_classes
        ]
-        conf_loc_flatten = ops.reshape(conf_loc, shape=new_shape)
+        conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
        mbox_confs.append(conf_loc_flatten)
 
        if len(box_results) == 1:
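
The Python-side edits are mechanical: reshape is no longer re-exported from the auto-generated ops module but lives in layers/nn.py, so every call site swaps the module while keeping its arguments. A minimal usage sketch of the call pattern used throughout detection.py (assuming a built fluid installation; the variable names are ours):

import paddle.fluid as fluid

scores = fluid.layers.data(name='scores', shape=[4, 21], dtype='float32')
# Same pattern as in detection.py: collapse the leading dimensions and let
# the reshape op infer the -1 dimension at runtime.
flat = fluid.layers.reshape(x=scores, shape=[-1, 21])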
