Skip to content

Commit 9d6243b

Browse files
wanghaoshuangqingqing01
authored andcommitted
Fix crop op. (#12603)
* Fix infer shape of crop op. * Speed crop op.
1 parent 49ad570 commit 9d6243b

File tree

3 files changed

+63
-19
lines changed

3 files changed

+63
-19
lines changed

paddle/fluid/operators/crop_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -188,6 +188,7 @@ namespace ops = paddle::operators;
188188
REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
189189
paddle::framework::DefaultGradOpDescMaker<true>);
190190
REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
191-
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
191+
REGISTER_OP_CPU_KERNEL(
192+
crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>);
192193
REGISTER_OP_CPU_KERNEL(
193194
crop_grad, ops::CropGradKernel<paddle::platform::CPUDeviceContext, float>);

paddle/fluid/operators/crop_op.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include "paddle/fluid/operators/crop_op.h"
1717

1818
namespace ops = paddle::operators;
19-
REGISTER_OP_CUDA_KERNEL(crop, ops::CropKernel<float>);
19+
REGISTER_OP_CUDA_KERNEL(
20+
crop, ops::CropKernel<paddle::platform::CUDADeviceContext, float>);
2021
REGISTER_OP_CUDA_KERNEL(
2122
crop_grad, ops::CropGradKernel<paddle::platform::CUDADeviceContext, float>);

paddle/fluid/operators/crop_op.h

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -58,32 +58,74 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
5858
return res;
5959
}
6060

61-
template <typename T>
61+
template <typename DeviceContext, typename T, size_t D>
62+
void CropFunction(const framework::ExecutionContext& context) {
63+
auto* x = context.Input<Tensor>("X");
64+
auto* out = context.Output<Tensor>("Out");
65+
auto out_dims = out->dims();
66+
if (out_dims[0] == -1) {
67+
out_dims[0] = x->dims()[0];
68+
}
69+
out->mutable_data<T>(out_dims, context.GetPlace());
70+
auto x_stride = framework::stride(x->dims());
71+
auto out_stride = framework::stride(out->dims());
72+
auto offsets = GetOffsets(context);
73+
int64_t offset = 0;
74+
for (size_t i = 0; i < offsets.size(); ++i) {
75+
offset += (x_stride[i] * offsets[i]);
76+
}
77+
78+
auto x_tensor = EigenTensor<T, D>::From(*x);
79+
auto out_tensor = EigenTensor<T, D>::From(*out);
80+
Eigen::array<int, D> e_offsets;
81+
Eigen::array<int, D> e_shape;
82+
for (size_t i = 0; i < D; ++i) {
83+
e_offsets[i] = offsets[i];
84+
e_shape[i] = out->dims()[i];
85+
}
86+
auto& place =
87+
*context.template device_context<DeviceContext>().eigen_device();
88+
out_tensor.device(place) = x_tensor.slice(e_offsets, e_shape);
89+
}
90+
91+
template <typename DeviceContext, typename T>
6292
class CropKernel : public framework::OpKernel<T> {
6393
public:
6494
void Compute(const framework::ExecutionContext& context) const override {
65-
auto* x = context.Input<Tensor>("X");
66-
auto* out = context.Output<Tensor>("Out");
67-
const T* x_data = x->data<T>();
68-
T* out_data = out->mutable_data<T>(context.GetPlace());
69-
auto x_stride = framework::stride(x->dims());
70-
auto out_stride = framework::stride(out->dims());
71-
auto offsets = GetOffsets(context);
72-
int64_t offset = 0;
73-
for (size_t i = 0; i < offsets.size(); ++i) {
74-
offset += (x_stride[i] * offsets[i]);
95+
int rank = context.Input<Tensor>("X")->dims().size();
96+
switch (rank) {
97+
case 1:
98+
CropFunction<DeviceContext, T, 1>(context);
99+
break;
100+
case 2:
101+
CropFunction<DeviceContext, T, 2>(context);
102+
break;
103+
case 3:
104+
CropFunction<DeviceContext, T, 3>(context);
105+
break;
106+
case 4:
107+
CropFunction<DeviceContext, T, 4>(context);
108+
break;
109+
case 5:
110+
CropFunction<DeviceContext, T, 5>(context);
111+
break;
112+
case 6:
113+
CropFunction<DeviceContext, T, 6>(context);
114+
break;
115+
default:
116+
PADDLE_THROW(
117+
"CropOp only support tensors with no more than 6 dimensions.");
75118
}
76-
StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
77-
out->dims(), out_stride, out_data);
78119
}
79120
};
80121

81122
template <typename DeviceContext, typename T, size_t D>
82123
void CropGradFunction(const framework::ExecutionContext& context) {
83124
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
125+
auto* x = context.Input<Tensor>("X");
84126
if (d_x != nullptr) {
85127
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
86-
d_x->mutable_data<T>(context.GetPlace());
128+
d_x->mutable_data<T>(x->dims(), context.GetPlace());
87129
auto offsets = GetOffsets(context);
88130
Eigen::array<std::pair<int, int>, D> paddings;
89131
for (size_t i = 0; i < D; ++i) {

0 commit comments

Comments
 (0)