Skip to content

Commit 0c319e0

Browse files
Add affine grid generator op (#12238)
* Add affine grid generator. * Fix affine grid. * Add unit test. * Add CPU kernel and fix unit test. * Fix CPU kernel. * Refine code. test=develop * Fix python api. test=develop * Update python api. test=develop * Fix comment. test=develop * Rename affine_grid_generator to affine_grid and enhance unit test. test=develop * Fix unit test. test=develop
1 parent d325e66 commit 0c319e0

File tree

9 files changed

+817
-38
lines changed

9 files changed

+817
-38
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
174174
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
175175
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
176176
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
177+
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
177178
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
178179
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
179180
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/platform/cudnn_helper.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
using Tensor = framework::Tensor;
22+
using ScopedSpatialTransformerDescriptor =
23+
platform::ScopedSpatialTransformerDescriptor;
24+
25+
template <typename T>
26+
class CUDNNAffineGridOpKernel : public framework::OpKernel<T> {
27+
public:
28+
void Compute(const framework::ExecutionContext& ctx) const override {
29+
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
30+
"It must use CUDAPlace.");
31+
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
32+
auto handle = dev_ctx.cudnn_handle();
33+
auto* theta = ctx.Input<Tensor>("Theta");
34+
auto* output = ctx.Output<Tensor>("Output");
35+
const T* theta_data = theta->data<T>();
36+
37+
int n = theta->dims()[0];
38+
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
39+
Tensor h_sizes;
40+
int* h_size_data;
41+
if (size_attr.size() == 0) {
42+
auto* output_shape = ctx.Input<Tensor>("OutputShape");
43+
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
44+
h_size_data = h_sizes.data<int>();
45+
} else {
46+
h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
47+
h_size_data[0] = n;
48+
h_size_data[1] = size_attr[1];
49+
h_size_data[2] = size_attr[2];
50+
h_size_data[3] = size_attr[3];
51+
}
52+
53+
T* output_data = output->mutable_data<T>(
54+
{n, h_size_data[2], h_size_data[3], 2}, ctx.GetPlace());
55+
ScopedSpatialTransformerDescriptor st_desc;
56+
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
57+
st_desc.descriptor<T>(4, h_size_data);
58+
59+
PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorForward(
60+
handle, cudnn_st_desc, theta_data, output_data));
61+
}
62+
};
63+
64+
template <typename T>
65+
class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
66+
public:
67+
void Compute(const framework::ExecutionContext& ctx) const override {
68+
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
69+
"It must use CUDAPlace.");
70+
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
71+
auto handle = dev_ctx.cudnn_handle();
72+
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
73+
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
74+
75+
int n = output_grad->dims()[0];
76+
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
77+
Tensor h_sizes;
78+
int* h_size_data;
79+
if (size_attr.size() == 0) {
80+
auto* output_shape = ctx.Input<Tensor>("OutputShape");
81+
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
82+
h_size_data = h_sizes.data<int>();
83+
} else {
84+
h_size_data = h_sizes.mutable_data<int>({4}, platform::CPUPlace());
85+
h_size_data[0] = n;
86+
h_size_data[1] = size_attr[1];
87+
h_size_data[2] = size_attr[2];
88+
h_size_data[3] = size_attr[3];
89+
}
90+
91+
ScopedSpatialTransformerDescriptor st_desc;
92+
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
93+
st_desc.descriptor<T>(4, h_size_data);
94+
95+
const T* output_grad_data = output_grad->data<T>();
96+
T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace());
97+
98+
PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorBackward(
99+
handle, cudnn_st_desc, output_grad_data, theta_grad_data));
100+
}
101+
};
102+
103+
} // namespace operators
104+
} // namespace paddle
105+
106+
namespace plat = paddle::platform;
107+
// Register the float and double instantiations of the forward kernel for the
// affine_grid op, tagged CUDNN, on plat::CUDAPlace.
REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
108+
paddle::operators::CUDNNAffineGridOpKernel<float>,
109+
paddle::operators::CUDNNAffineGridOpKernel<double>);
110+
// Same registration for the backward kernel of the affine_grid_grad op.
REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace,
111+
paddle::operators::CUDNNAffineGridGradOpKernel<float>,
112+
paddle::operators::CUDNNAffineGridGradOpKernel<double>);
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/affine_grid_op.h"
16+
#include <string>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#ifdef PADDLE_WITH_CUDA
19+
#include "paddle/fluid/platform/cudnn_helper.h"
20+
#endif
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
using Tensor = framework::Tensor;
26+
27+
template <typename T>
28+
struct Linspace<paddle::platform::CPUDeviceContext, T> {
29+
framework::Tensor operator()(T start, T end, int count,
30+
const framework::ExecutionContext& ctx) {
31+
Tensor numbers;
32+
T* number_data = numbers.mutable_data<T>({count}, platform::CPUPlace());
33+
T slice = (end - start) / (T)(count - 1);
34+
for (int i = 0; i < count; ++i) {
35+
number_data[i] = start + (T)i * slice;
36+
}
37+
return numbers;
38+
}
39+
};
40+
41+
class AffineGridOp : public framework::OperatorWithKernel {
42+
public:
43+
using framework::OperatorWithKernel::OperatorWithKernel;
44+
void InferShape(framework::InferShapeContext* ctx) const override {
45+
PADDLE_ENFORCE(ctx->HasInput("Theta"),
46+
"Input(Theta) of AffineGridOp should not be null.");
47+
PADDLE_ENFORCE(ctx->HasOutput("Output"),
48+
"Output(Output) of AffineGridOp should not be null.");
49+
auto theta_dims = ctx->GetInputDim("Theta");
50+
PADDLE_ENFORCE(theta_dims.size() == 3,
51+
"AffineGrid's Input(Theta) should be 3-D tensor.");
52+
53+
auto output_shape = ctx->Attrs().Get<std::vector<int>>("output_shape");
54+
if (output_shape.size() == 0) {
55+
PADDLE_ENFORCE(ctx->HasInput("OutputShape"),
56+
"Input(OutputShape) of AffineGridOp should not be null if "
57+
"attr(output_shape) is not configured.");
58+
auto output_shape_dims = ctx->GetInputDim("OutputShape");
59+
PADDLE_ENFORCE(output_shape_dims.size() == 1,
60+
"AffineGrid's Input(OutputShape) should be 1-D tensor.");
61+
} else {
62+
PADDLE_ENFORCE(output_shape.size() == 4,
63+
"The size of attr(output_shape) should be 4.");
64+
}
65+
66+
PADDLE_ENFORCE(theta_dims[1] == 2, "Input(theta) dims[1] should be 2.");
67+
PADDLE_ENFORCE(theta_dims[2] == 3, "Input(theta) dims[2] should be 3.");
68+
// N * H * W * 2
69+
ctx->SetOutputDim("Output",
70+
framework::make_ddim({theta_dims[0], -1, -1, 2}));
71+
ctx->ShareLoD("Theta", "Output");
72+
}
73+
74+
protected:
75+
framework::OpKernelType GetExpectedKernelType(
76+
const framework::ExecutionContext& ctx) const override {
77+
framework::LibraryType library{framework::LibraryType::kPlain};
78+
#ifdef PADDLE_WITH_CUDA
79+
if (platform::CanCUDNNBeUsed(ctx)) {
80+
library = framework::LibraryType::kCUDNN;
81+
}
82+
#endif
83+
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
84+
return framework::OpKernelType(data_type, ctx.GetPlace(),
85+
framework::DataLayout::kAnyLayout, library);
86+
}
87+
};
88+
89+
// Declares the inputs, outputs, attributes and user-facing documentation of
// the affine_grid op.
class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Theta",
        "(Tensor) A batch of affine transform parameters with shape [N, 2, 3]. "
        "It is used to transform coordinate (x_0, y_0) to coordinate (x_1, "
        "y_1).");
    // Optional: only needed when attr(output_shape) is left empty.
    AddInput("OutputShape",
             "(Tensor) The shape of target image with format [N, C, H, W].")
        .AsDispensable();
    AddOutput("Output", "(Tensor) Output Tensor with shape [N, H, W, 2].");
    // The description states the default that SetDefault actually applies.
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default true) Only used in cudnn kernel, need install cudnn")
        .SetDefault(true);
    AddAttr<std::vector<int>>(
        "output_shape",
        "The target output image shape with format [N, C, H, W].")
        .SetDefault(std::vector<int>());

    AddComment(R"DOC(
It generates a grid of (x,y) coordinates using the parameters of the
affine transformation that correspond to a set of points where the input
feature map should be sampled to produce the transformed output feature map.

Given:
Theta = [[[x_11, x_12, x_13]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]

OutputShape = [2, 3, 5, 5]

Step 1:

Generate relative coordinates according to OutputShape.
The values of relative coordinates are in the interval between -1 and 1.
The shape of the relative coordinates is [2, H, W] as below:

C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5]
[ 1. 1. 1. 1. 1. ]]
[[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.

Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
[ 0.5 -1. 1. ]
[ 1. -1. 1. ]
[-1. -0.5 1. ]
[-0.5 -0.5 1. ]
[ 0. -0.5 1. ]
[ 0.5 -0.5 1. ]
[ 1. -0.5 1. ]
[-1. 0. 1. ]
[-0.5 0. 1. ]
[ 0. 0. 1. ]
[ 0.5 0. 1. ]
[ 1. 0. 1. ]
[-1. 0.5 1. ]
[-0.5 0.5 1. ]
[ 0. 0.5 1. ]
[ 0.5 0.5 1. ]
[ 1. 0.5 1. ]
[-1. 1. 1. ]
[-0.5 1. 1. ]
[ 0. 1. 1. ]
[ 0.5 1. 1. ]
[ 1. 1. 1. ]]
Step3:
Compute output by equation $$Output[i] = C_ * Theta[i]^T$$
)DOC");
  }
};
173+
174+
class AffineGridOpGrad : public framework::OperatorWithKernel {
175+
public:
176+
using framework::OperatorWithKernel::OperatorWithKernel;
177+
void InferShape(framework::InferShapeContext* ctx) const override {
178+
auto theta_dims = ctx->GetInputDim("Theta");
179+
if (ctx->HasOutput(framework::GradVarName("Theta"))) {
180+
ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims);
181+
}
182+
}
183+
184+
protected:
185+
framework::OpKernelType GetExpectedKernelType(
186+
const framework::ExecutionContext& ctx) const override {
187+
framework::LibraryType library_{framework::LibraryType::kPlain};
188+
#ifdef PADDLE_WITH_CUDA
189+
if (platform::CanCUDNNBeUsed(ctx)) {
190+
library_ = framework::LibraryType::kCUDNN;
191+
}
192+
#endif
193+
return framework::OpKernelType(
194+
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
195+
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
196+
}
197+
};
198+
199+
class AffineGridGradMaker : public framework::SingleGradOpDescMaker {
200+
public:
201+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
202+
203+
protected:
204+
std::unique_ptr<framework::OpDesc> Apply() const override {
205+
auto* op = new framework::OpDesc();
206+
op->SetType("affine_grid_grad");
207+
op->SetInput("Theta", Input("Theta"));
208+
op->SetInput("OutputShape", Input("OutputShape"));
209+
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
210+
211+
op->SetAttrMap(Attrs());
212+
213+
op->SetOutput(framework::GradVarName("Theta"), InputGrad("Theta"));
214+
return std::unique_ptr<framework::OpDesc>(op);
215+
}
216+
};
217+
218+
} // namespace operators
219+
} // namespace paddle
220+
221+
namespace ops = paddle::operators;
222+
// Register the forward op with its proto maker and its grad-op maker.
REGISTER_OPERATOR(affine_grid, ops::AffineGridOp, ops::AffineGridOpMaker,
223+
ops::AffineGridGradMaker);
224+
// Register the backward op.
REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);
225+
226+
// CPU forward kernels for float and double.
REGISTER_OP_CPU_KERNEL(
227+
affine_grid,
228+
ops::AffineGridOpKernel<paddle::platform::CPUDeviceContext, float>,
229+
ops::AffineGridOpKernel<paddle::platform::CPUDeviceContext, double>);
230+
// CPU backward kernels for float and double.
REGISTER_OP_CPU_KERNEL(
231+
affine_grid_grad,
232+
ops::AffineGridGradOpKernel<paddle::platform::CPUDeviceContext, float>,
233+
ops::AffineGridGradOpKernel<paddle::platform::CPUDeviceContext, double>);

0 commit comments

Comments
 (0)