Skip to content

Commit c9ba51e

Browse files
committed
Merge remote-tracking branch 'ups/develop' into feature/libxsmm
2 parents 64a8e6d + 16aca3c commit c9ba51e

File tree

10 files changed

+352
-167
lines changed

10 files changed

+352
-167
lines changed

paddle/fluid/operators/detection/rpn_target_assign_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
8686
std::minstd_rand engine,
8787
std::vector<int>* inds) const {
8888
std::uniform_real_distribution<float> uniform(0, 1);
89-
if (inds->size() > num) {
90-
for (int i = num; i < inds->size(); ++i) {
89+
const int64_t size = static_cast<int64_t>(inds->size());
90+
if (size > num) {
91+
for (int64_t i = num; i < size; ++i) {
9192
int rng_ind = std::floor(uniform(engine) * i);
9293
if (rng_ind < num)
9394
std::iter_swap(inds->begin() + rng_ind + offset,

paddle/fluid/operators/im2sequence_op.cc

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/im2sequence_op.h"
16+
#include <string>
1617
#include <vector>
1718

1819
namespace paddle {
@@ -28,20 +29,19 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
2829
"Input(X) of Im2SequenceOp should not be null.");
2930
PADDLE_ENFORCE(ctx->HasOutput("Out"),
3031
"Output(Out) of Im2SequenceOp op should not be null.");
31-
3232
auto in_dim = ctx->GetInputDim("X");
33+
3334
PADDLE_ENFORCE_EQ(in_dim.size(), 4,
3435
"Input(X) format must be 4D tensor, eg., NCHW.");
35-
36-
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
37-
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
38-
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
39-
4036
int batch_size = in_dim[0];
4137
int img_channels = in_dim[1];
4238
int img_height = in_dim[2];
4339
int img_width = in_dim[3];
4440

41+
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
42+
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
43+
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
44+
4545
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
4646
paddings[2], strides[0]);
4747
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
@@ -61,6 +61,10 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
6161
"C: channels"
6262
"H: height"
6363
"W: width");
64+
AddInput("Y",
65+
"(Tensor) The input tensor of image real size(H, W)."
66+
"2-D with shape [batchsize, 2]")
67+
.AsDispensable();
6468
AddOutput("Out", "(LodTensor) The output data of im2sequence op,");
6569
AddAttr<std::vector<int>>("kernels",
6670
"(vector<int>), the "
@@ -73,6 +77,13 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
7377
"(vector<int> default:{0, 0, 0, 0}), the "
7478
"paddings(up_pad, left_pad, down_pad, right_pad)")
7579
.SetDefault({0, 0, 0, 0});
80+
AddAttr<std::vector<int>>("out_stride",
81+
"the attribute is valid only when input(Y)"
82+
"is not NULL.this attribute represents the"
83+
"scaling of the pic through the CNN"
84+
"(vector<int> default:{1,1}),the out_stride"
85+
" (out_stride_height, out_stride_width)")
86+
.SetDefault({1, 1});
7687
AddComment(R"DOC(
7788
This op uses kernels to scan images and converts these images to sequences.
7889
After expanding, The number of time steps are output_height * output_width
@@ -123,7 +134,7 @@ output.data = [[ 6. 2. 8. 3. 2. 4. 6. 3.]
123134
[ 7. 1. 7. 9. 2. 1. 3. 5.]
124135
[ 5. 7. 2. 4. 1. 3. 9. 0.]
125136
[ 7. 9. 4. 8. 3. 5. 0. 8.]]
126-
output.dims = {8, 9}
137+
output.dims = {8, 8}
127138
output.lod = [[0, 4, 8]]
128139
129140
)DOC");

paddle/fluid/operators/im2sequence_op.h

Lines changed: 91 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include <string>
1617
#include <vector>
1718
#include "paddle/fluid/framework/data_layout.h"
1819
#include "paddle/fluid/framework/eigen.h"
@@ -39,50 +40,106 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
3940
void Compute(const framework::ExecutionContext& ctx) const override {
4041
const Tensor* in = ctx.Input<Tensor>("X");
4142
LoDTensor* out = ctx.Output<LoDTensor>("Out");
42-
out->mutable_data<T>(ctx.GetPlace());
43-
// TODO(wanghaoshuang): Add layout checker after 'set_layout'
44-
// being available for python API
45-
// PADDLE_ENFORCE_EQ(in->layout(), framework::DataLayout::kNCHW,
46-
// "Input(X) layout must be NCHW");
4743
auto in_dim = in->dims();
4844
int batch_size = in_dim[0];
4945
int img_channels = in_dim[1];
5046
int img_height = in_dim[2];
5147
int img_width = in_dim[3];
52-
5348
auto kernels = ctx.Attr<std::vector<int>>("kernels");
5449
auto strides = ctx.Attr<std::vector<int>>("strides");
5550
auto paddings = ctx.Attr<std::vector<int>>("paddings");
56-
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
57-
paddings[2], strides[0]);
58-
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
59-
paddings[3], strides[1]);
60-
61-
const std::vector<int> dilations({1, 1});
62-
63-
auto out_dims = out->dims();
64-
out->Resize({batch_size, out->numel() / batch_size});
65-
for (int i = 0; i < batch_size; i++) {
66-
const Tensor src =
67-
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
68-
Tensor dst = out->Slice(i, i + 1).Resize(
69-
{output_height, output_width, img_channels, kernels[0], kernels[1]});
70-
71-
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
72-
auto& dev_ctx = ctx.template device_context<DeviceContext>();
73-
f(dev_ctx, src, dilations, strides, paddings, &dst);
74-
}
75-
out->Resize(out_dims);
76-
77-
// set lod information
78-
// TODO(wanghaoshuang): Move this to InferShape
79-
framework::LoD lod(1);
80-
lod[0].reserve(batch_size + 1);
81-
for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
51+
if (ctx.HasInput("Y") && batch_size > 1) {
52+
const Tensor* imgrealsize = ctx.Input<Tensor>("Y");
53+
auto out_stride = ctx.Attr<std::vector<int>>("out_stride");
54+
Tensor cpu_shape_tensor;
55+
TensorCopySync(*imgrealsize, platform::CPUPlace(), &cpu_shape_tensor);
56+
std::vector<int> imgreal_h;
57+
std::vector<int> imgreal_w;
58+
std::vector<int> output_height;
59+
std::vector<int> output_width;
60+
int result = 0;
61+
for (int i = 0; i < batch_size; i++) {
62+
int tmp_real_h = static_cast<int>((cpu_shape_tensor.data<T>())[2 * i]);
63+
int tmp_real_w =
64+
static_cast<int>((cpu_shape_tensor.data<T>())[2 * i + 1]);
65+
if (tmp_real_h % out_stride[0] == 0) {
66+
tmp_real_h = tmp_real_h / out_stride[0];
67+
} else {
68+
tmp_real_h = tmp_real_h / out_stride[0] + 1;
69+
}
70+
if (tmp_real_w % out_stride[1] == 0) {
71+
tmp_real_w = tmp_real_w / out_stride[1];
72+
} else {
73+
tmp_real_w = tmp_real_w / out_stride[1] + 1;
74+
}
75+
imgreal_h.push_back(tmp_real_h);
76+
imgreal_w.push_back(tmp_real_w);
77+
output_height.push_back(Im2SeqOutputSize(
78+
imgreal_h[i], kernels[0], paddings[0], paddings[2], strides[0]));
79+
output_width.push_back(Im2SeqOutputSize(
80+
imgreal_w[i], kernels[1], paddings[1], paddings[3], strides[1]));
81+
result += output_height[i] * output_width[i];
82+
}
83+
84+
out->mutable_data<T>({result, img_channels * kernels[0] * kernels[1]},
85+
ctx.GetPlace());
86+
87+
const std::vector<int> dilations({1, 1});
88+
int offset_out = 0;
89+
for (int i = 0; i < batch_size; i++) {
90+
const Tensor src =
91+
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
92+
Tensor dst = out->Slice(offset_out,
93+
offset_out + output_height[i] * output_width[i])
94+
.Resize({output_height[i], output_width[i],
95+
img_channels, kernels[0], kernels[1]});
96+
offset_out += output_height[i] * output_width[i];
97+
98+
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
99+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
100+
f(dev_ctx, src, dilations, strides, paddings, &dst);
101+
}
102+
framework::LoD lod(1);
103+
lod[0].reserve(batch_size + 1);
104+
int offset = 0;
105+
lod[0].push_back(offset);
106+
for (int i = 0; i < batch_size; ++i) {
107+
offset += output_height[i] * output_width[i];
108+
lod[0].push_back(offset);
109+
}
110+
out->set_lod(lod);
111+
} else {
112+
out->mutable_data<T>(ctx.GetPlace());
113+
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
114+
paddings[2], strides[0]);
115+
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
116+
paddings[3], strides[1]);
117+
118+
const std::vector<int> dilations({1, 1});
119+
auto out_dims = out->dims();
120+
out->Resize({batch_size, out->numel() / batch_size});
121+
for (int i = 0; i < batch_size; i++) {
122+
const Tensor src =
123+
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
124+
Tensor dst =
125+
out->Slice(i, i + 1).Resize({output_height, output_width,
126+
img_channels, kernels[0], kernels[1]});
127+
128+
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
129+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
130+
f(dev_ctx, src, dilations, strides, paddings, &dst);
131+
}
132+
out->Resize(out_dims);
133+
framework::LoD lod(1);
134+
lod[0].reserve(batch_size + 1);
135+
int offset = 0;
82136
lod[0].push_back(offset);
83-
offset += output_height * output_width;
137+
for (int i = 0; i < batch_size; ++i) {
138+
offset += output_height * output_width;
139+
lod[0].push_back(offset);
140+
}
141+
out->set_lod(lod);
84142
}
85-
out->set_lod(lod);
86143
}
87144
};
88145

paddle/fluid/operators/math/im2col.cc

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
4343
int col_height = col->dims()[3];
4444
int col_width = col->dims()[4];
4545

46-
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
47-
((dilation[0] * (filter_height - 1) + 1))) /
48-
stride[0] +
49-
1,
50-
col_height,
51-
"Output_height and padding(padding_up, padding_down) are "
52-
"inconsistent.");
53-
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
54-
((dilation[1] * (filter_width - 1) + 1))) /
55-
stride[1] +
56-
1,
57-
col_width,
58-
"Output_height and padding(padding_up, padding_down) are "
59-
"inconsistent.");
60-
6146
int channels_col = im_channels * filter_height * filter_width;
6247

6348
const T* im_data = im.data<T>();
@@ -178,17 +163,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
178163
int col_height = col->dims()[0];
179164
int col_width = col->dims()[1];
180165

181-
PADDLE_ENFORCE_EQ(
182-
(im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
183-
col_height,
184-
"Output_height and padding(padding_up, padding_down) are "
185-
"inconsistent.");
186-
PADDLE_ENFORCE_EQ(
187-
(im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
188-
col_width,
189-
"col_width and padding(padding_left, padding_right) are "
190-
"inconsistent.");
191-
192166
const T* im_data = im.data<T>();
193167
T* col_data = col->data<T>();
194168

paddle/fluid/operators/math/im2col.cu

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -77,21 +77,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
7777
int col_height = col->dims()[3];
7878
int col_width = col->dims()[4];
7979

80-
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
81-
(dilation[0] * (filter_height - 1) + 1)) /
82-
stride[0] +
83-
1,
84-
col_height,
85-
"Output_height and padding(padding_up, padding_down) are "
86-
"inconsistent.");
87-
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
88-
(dilation[1] * (filter_width - 1) + 1)) /
89-
stride[1] +
90-
1,
91-
col_width,
92-
"col_width and padding(padding_left, padding_right) are "
93-
"inconsistent.");
94-
9580
int num_outputs = im_channels * col_height * col_width;
9681
int blocks = (num_outputs + 1024 - 1) / 1024;
9782
int block_x = 512;
@@ -274,21 +259,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
274259
int col_height = col->dims()[0];
275260
int col_width = col->dims()[1];
276261

277-
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
278-
(dilation[0] * (filter_height - 1) + 1)) /
279-
stride[0] +
280-
1,
281-
col_height,
282-
"Output_height and padding(padding_up, padding_down) are "
283-
"inconsistent.");
284-
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
285-
(dilation[1] * (filter_width - 1) + 1)) /
286-
stride[1] +
287-
1,
288-
col_width,
289-
"col_width and padding(padding_left, padding_right) are "
290-
"inconsistent.");
291-
292262
int block_dim_x = 0;
293263
int block_dim_y = 0;
294264
if (filter_height <= 4 && filter_width <= 4) {

python/paddle/fluid/backward.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def _append_grad_suffix_(name):
123123
def _addup_repetitive_outputs_(op_descs):
124124
"""
125125
In the backward part, a variable may be the output of more than one op.
126-
In this case, the variable should be the accumulation of all the outputs.
126+
And one op may yield its multiple outputs to the same variable.
127+
In these cases, the variable should be the accumulation of all the outputs.
127128
`sum_op`s are added to implement the accumulate.
128129
"""
129130
pending_sum_ops = []
@@ -136,29 +137,46 @@ def _addup_repetitive_outputs_(op_descs):
136137
"sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]},
137138
{"use_mkldnn": False}), idx))
138139
renamed_vars[var_name] = [var_name]
139-
for var_name in op_desc.output_arg_names():
140-
if var_name == core.empty_var_name(
141-
) or var_name in op_desc.input_arg_names():
142-
# empty variable or inplace op
143-
continue
144-
if len(renamed_vars[var_name]) == 0:
145-
# it's the first time we get the variable
146-
renamed_vars[var_name] = [var_name]
147-
else:
148-
if len(renamed_vars[var_name]) == 1:
140+
for param_idx, param_name in enumerate(op_desc.output_names()):
141+
arg_names = op_desc.output(param_name)
142+
for arg_idx, var_name in enumerate(arg_names):
143+
if var_name == core.empty_var_name(
144+
) or var_name in op_desc.input_arg_names():
145+
# empty variable or inplace op
146+
continue
147+
if len(renamed_vars[var_name]) == 0:
148+
# it's the first time we get the variable
149+
renamed_vars[var_name] = [var_name]
150+
else:
151+
if len(renamed_vars[var_name]) == 1:
152+
new_name = var_name + "@RENAME@" + \
153+
str(var_rename_count[var_name])
154+
var_rename_count[var_name] += 1
155+
# rename original var_name
156+
renamed_vars[var_name][0] = new_name
157+
_rename_arg_(op_descs, var_name, new_name, 0, idx)
158+
_rename_arg_(pending_sum_ops, var_name, new_name)
159+
160+
for p in op_desc.output_names()[:param_idx]:
161+
p_arg_names = op_desc.output(p)
162+
if var_name in p_arg_names:
163+
op_desc.set_output(p, [
164+
new_name if x == var_name else x
165+
for x in p_arg_names
166+
])
167+
168+
arg_names = [
169+
new_name if x == var_name else x
170+
for x in arg_names[:arg_idx]
171+
] + arg_names[arg_idx:]
172+
149173
new_name = var_name + "@RENAME@" + \
150174
str(var_rename_count[var_name])
151175
var_rename_count[var_name] += 1
152-
# rename original var_name
153-
renamed_vars[var_name][0] = new_name
154-
_rename_arg_(op_descs, var_name, new_name, 0, idx)
155-
_rename_arg_(pending_sum_ops, var_name, new_name)
156-
157-
new_name = var_name + "@RENAME@" + \
158-
str(var_rename_count[var_name])
159-
var_rename_count[var_name] += 1
160-
op_desc.rename_output(var_name, new_name)
161-
renamed_vars[var_name].append(new_name)
176+
arg_names[arg_idx] = new_name
177+
op_desc.set_output(param_name, arg_names)
178+
renamed_vars[var_name].append(new_name)
179+
162180
for var_name, inputs in renamed_vars.iteritems():
163181
if len(inputs) > 1:
164182
pending_sum_ops.append(

0 commit comments

Comments
 (0)