
Commit 64f3e3e

Merge pull request #14069 from heavengate/grid_sampler
Grid sampler operator for spatial transformer network.
2 parents 8690deb + decaeb1 commit 64f3e3e

File tree

7 files changed, +872 -0 lines changed


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -178,6 +178,7 @@ paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], var
  paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
  paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
  paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
+ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
  paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
  paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
  paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
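
The new spec line above exposes the operator in Python as paddle.fluid.layers.grid_sampler(x, grid, name=None). Below is a minimal usage sketch, assuming a PaddlePaddle build that includes this commit; the shapes, variable names, and batch size are illustrative only.

import paddle.fluid as fluid

# Input feature map with an explicit [N, C, H, W] shape (illustrative values).
x = fluid.layers.data(name='x', shape=[4, 3, 32, 32], dtype='float32',
                      append_batch_size=False)
# One 2x3 affine transform matrix per sample, consumed by affine_grid.
theta = fluid.layers.data(name='theta', shape=[4, 2, 3], dtype='float32',
                          append_batch_size=False)
# Sampling grid of shape [N, H, W, 2]; out_shape mirrors the input's [N, C, H, W].
grid = fluid.layers.affine_grid(theta=theta, out_shape=[4, 3, 32, 32])
# Bilinear sampling of x at the grid locations; output shape is [N, C, H, W].
out = fluid.layers.grid_sampler(x=x, grid=grid)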
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using ScopedSpatialTransformerDescriptor =
    platform::ScopedSpatialTransformerDescriptor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

template <typename T>
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto* input = ctx.Input<Tensor>("X");
    auto* grid = ctx.Input<Tensor>("Grid");
    auto* output = ctx.Output<Tensor>("Output");

    int n = input->dims()[0];
    int c = input->dims()[1];
    int h = input->dims()[2];
    int w = input->dims()[3];
    const int size[4] = {n, c, h, w};

    const T* input_data = input->data<T>();
    const T* grid_data = grid->data<T>();
    T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());

    ScopedSpatialTransformerDescriptor st_desc;
    cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
        st_desc.descriptor<T>(4, size);

    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        DataLayout::kNCHW, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        DataLayout::kNCHW, framework::vectorize2int(output->dims()));

    CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
        handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
        input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
        output_data));
  }
};

template <typename T>
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto* input = ctx.Input<Tensor>("X");
    auto* grid = ctx.Input<Tensor>("Grid");
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));

    auto output_grad_dims = output_grad->dims();
    const int n = output_grad_dims[0];
    const int c = output_grad_dims[1];
    const int h = output_grad_dims[2];
    const int w = output_grad_dims[3];
    const int size[4] = {n, c, h, w};

    ScopedSpatialTransformerDescriptor st_dest;
    cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
        st_dest.descriptor<T>(4, size);

    const T* input_data = input->data<T>();
    const T* grid_data = grid->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
    T* grid_grad_data =
        grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());

    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor input_grad_desc;
    ScopedTensorDescriptor output_grad_desc;
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        DataLayout::kNCHW, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_input_grad_desc =
        input_grad_desc.descriptor<T>(
            DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
    cudnnTensorDescriptor_t cudnn_output_grad_desc =
        output_grad_desc.descriptor<T>(
            DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));

    CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
        handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
        input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
        input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
        output_grad_data, grid_data, CudnnDataType<T>::kZero(),
        grid_grad_data));
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNGridSampleOpKernel<float>,
                   paddle::operators::CUDNNGridSampleOpKernel<double>);
REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNGridSampleGradOpKernel<float>,
                   paddle::operators::CUDNNGridSampleGradOpKernel<double>);
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class GridSampleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of GridSampleOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grid"),
                   "Input(Grid) of GridSampleOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Output"),
                   "Output(Output) of GridSampleOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto grid_dims = ctx->GetInputDim("Grid");
    PADDLE_ENFORCE(x_dims.size() == 4,
                   "Input(X) of GridSampleOp should be 4-D Tensor.");
    PADDLE_ENFORCE(grid_dims.size() == 4,
                   "Input(Grid) of GridSampleOp should be 4-D Tensor.");
    PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
    PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
                      "Input(X) and Input(Grid) dims[0] should be equal.");
    PADDLE_ENFORCE_EQ(
        grid_dims[1], x_dims[2],
        "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
    PADDLE_ENFORCE_EQ(
        grid_dims[2], x_dims[3],
        "Input(X) dims[3] and Input(Grid) dims[2] should be equal.");

    ctx->SetOutputDim("Output", x_dims);
    ctx->ShareLoD("X", "Output");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kCUDNN;
    }
#endif
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
        framework::DataLayout::kAnyLayout, library_);
  }
};

class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input data of GridSampleOp. "
             "This is a 4-D tensor with shape of [N, C, H, W]");
    AddInput(
        "Grid",
        "(Tensor) The input grid of GridSampleOp generated by AffineGridOp. "
        "This is a 4-D tensor with shape of [N, H, W, 2], which is the "
        "concatenation of x and y coordinates with shape [N, H, W] in the "
        "last dimension");
    AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]");
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default true) Only used in cudnn kernel, need install cudnn")
        .SetDefault(true);

    AddComment(R"DOC(
      This operation samples input X by using bilinear interpolation based on
      a flow field grid, which is usually generated by affine_grid. The grid of
      shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
      with shape [N, H, W] each, where grid_x indexes the 4th dimension
      (the width dimension) of input data X and grid_y indexes the 3rd
      dimension (the height dimension). The final result is the bilinear
      interpolation of the 4 nearest corner points.

      Step 1:
        Get (x, y) grid coordinates and scale them to [0, W-1] and [0, H-1].

        grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
        grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)

      Step 2:
        Index input data X with grid (x, y) in each [H, W] area, and bilinearly
        interpolate the point value from the 4 nearest corner points.

        wn ------- y_n ------- en
        |           |           |
        |          d_n          |
        |           |           |
        x_w --d_w-- grid --d_e-- x_e
        |           |           |
        |          d_s          |
        |           |           |
        ws ------- y_s ------- es

        x_w = floor(x)              // west side x coord
        x_e = x_w + 1               // east side x coord
        y_n = floor(y)              // north side y coord
        y_s = y_n + 1               // south side y coord

        d_w = grid_x - x_w          // distance to west side
        d_e = x_e - grid_x          // distance to east side
        d_n = grid_y - y_n          // distance to north side
        d_s = y_s - grid_y          // distance to south side

        wn = X[:, :, y_n, x_w]      // north-west point value
        en = X[:, :, y_n, x_e]      // north-east point value
        ws = X[:, :, y_s, x_w]      // south-west point value
        es = X[:, :, y_s, x_e]      // south-east point value

        output = wn * d_e * d_s + en * d_w * d_s
               + ws * d_e * d_n + es * d_w * d_n
    )DOC");
  }
};

class GridSampleOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    auto input_dims = ctx->GetInputDim("X");
    auto grid_dims = ctx->GetInputDim("Grid");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
    }
    if (ctx->HasOutput(framework::GradVarName("Grid"))) {
      ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kCUDNN;
    }
#endif
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
        framework::DataLayout::kAnyLayout, library_);
  }
};

class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType("grid_sampler_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("Grid", Input("Grid"));
    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));

    op->SetAttrMap(Attrs());

    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
    return std::unique_ptr<framework::OpDesc>(op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
                  ops::GridSampleGradMaker);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);

REGISTER_OP_CPU_KERNEL(
    grid_sampler,
    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    grid_sampler_grad,
    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
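
For readers following the DOC comment above, the sketch below restates the same two-step sampling math in NumPy. It is a reference only, written against the formulas in the doc string; the function name bilinear_sample and the boundary clipping are illustrative and not part of the operator's code.

import numpy as np

def bilinear_sample(x, grid):
    """Reference bilinear sampling: x is [N, C, H, W] float, grid is [N, H, W, 2] in [-1, 1]."""
    n, c, h, w = x.shape
    # Step 1: rescale normalized coordinates to pixel coordinates.
    gx = 0.5 * (grid[..., 0] + 1) * (w - 1)   # grid_x in [0, W-1]
    gy = 0.5 * (grid[..., 1] + 1) * (h - 1)   # grid_y in [0, H-1]
    # Step 2: corner coordinates and distances (distances computed before clipping).
    x_w = np.floor(gx)
    y_n = np.floor(gy)
    d_w, d_e = gx - x_w, x_w + 1 - gx
    d_n, d_s = gy - y_n, y_n + 1 - gy
    x_w = np.clip(x_w, 0, w - 1).astype(int)
    x_e = np.clip(x_w + 1, 0, w - 1)
    y_n = np.clip(y_n, 0, h - 1).astype(int)
    y_s = np.clip(y_n + 1, 0, h - 1)
    out = np.empty_like(x)
    for i in range(n):
        wn = x[i][:, y_n[i], x_w[i]]   # north-west corner values, [C, H, W]
        en = x[i][:, y_n[i], x_e[i]]   # north-east corner values
        ws = x[i][:, y_s[i], x_w[i]]   # south-west corner values
        es = x[i][:, y_s[i], x_e[i]]   # south-east corner values
        out[i] = (wn * d_e[i] * d_s[i] + en * d_w[i] * d_s[i] +
                  ws * d_e[i] * d_n[i] + es * d_w[i] * d_n[i])
    return out

With a grid produced by an identity affine transform, every output location maps back to its own integer pixel coordinate, so the sketch returns the input unchanged, which is a convenient sanity check.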
