
Commit 966a6ce

Merge pull request #5826 from sweetsky0901/my_unpool_max_2d
My unpool max 2d
2 parents 5a3d136 + 4ffb73f commit 966a6ce

File tree: 9 files changed (+591, -2 lines)

paddle/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,7 @@ set(DEPS_OPS
     sum_op
     pool_op
     maxout_op
+    unpool_op
     pool_with_index_op
     conv_op
     conv_transpose_op
@@ -235,6 +236,7 @@ op_library(adagrad_op DEPS selected_rows_functor)
 op_library(conv_op DEPS vol2col)
 op_library(pool_op DEPS pooling)
 op_library(maxout_op DEPS maxouting)
+op_library(unpool_op DEPS unpooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
 op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)

paddle/operators/math/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -13,8 +13,9 @@ if(WITH_GPU)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
     nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
     nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
-    nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
     nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
+    nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
+    nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
     cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -26,8 +27,9 @@ else()
     cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
-    cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
     cc_library(maxouting SRCS maxouting.cc DEPS device_context)
+    cc_library(unpooling SRCS unpooling.cc DEPS device_context)
+    cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
 endif()

 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
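
Taken together with the change to paddle/operators/CMakeLists.txt above, this follows the existing maxouting pattern: the device-specific math code is built as its own library target (nv_library including unpooling.cu when WITH_GPU is set, plain cc_library otherwise), which op_library(unpool_op DEPS unpooling) then links against.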

paddle/operators/math/unpooling.cc

Lines changed: 91 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {
namespace math {

// CPU forward pass: scatter every input value to the flat output offset
// recorded in `indices`. All tensors are in NCHW format.
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          output_data[index] = input_data[i];
        }
        // Step all pointers forward to the next (batch, channel) feature map.
        input_data += input_feasize;
        indices_data += input_feasize;
        output_data += output_feasize;
      }
    }
  }
};

// CPU backward pass: gather each gradient back from the output offset
// recorded in `indices`, the exact mirror of the forward scatter.
template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const int* indices_data = indices.data<int>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          input_grad_data[i] = output_grad_data[index];
        }
        input_grad_data += input_feasize;
        indices_data += input_feasize;
        output_grad_data += output_feasize;
      }
    }
  }
};

template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
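
To see the forward scatter concretely, here is a minimal standalone sketch (not part of the commit) of the same per-feature-map logic on raw arrays. The 2x2 input, 4x4 output, and the argmax offsets in indices are all hypothetical values chosen for illustration:

#include <cstdio>

// Illustrative scatter, mirroring the inner loop of the CPU
// Unpool2dMaxFunctor for a single (batch, channel) feature map.
int main() {
  const int input_feasize = 2 * 2;   // one 2x2 pooled feature map
  const int output_feasize = 4 * 4;  // unpooled back to 4x4
  float input[input_feasize] = {9.f, 8.f, 7.f, 6.f};
  // Hypothetical flat offsets recorded by a preceding 2x2 max pool.
  int indices[input_feasize] = {0, 7, 9, 14};
  float output[output_feasize] = {0.f};  // positions that held no max stay zero

  for (int i = 0; i < input_feasize; ++i) {
    output[indices[i]] = input[i];  // the scatter: output_data[index] = input_data[i]
  }

  for (int h = 0; h < 4; ++h) {
    for (int w = 0; w < 4; ++w) printf("%4.1f ", output[4 * h + w]);
    printf("\n");
  }
  return 0;
}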

paddle/operators/math/unpooling.cu

Lines changed: 134 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
                                  const int* indices_data,
                                  const int input_height, const int input_width,
                                  const int channels, T* output_data,
                                  const int output_height,
                                  const int output_width) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  // Grid-stride loop over the flattened NCHW input.
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;        // batch index
    int boffset = i % in_n_stride;     // offset inside that batch item
    int cidx = boffset / in_c_stride;  // channel index
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}

template <typename T>
__global__ void KernelUnpool2dMaxGrad(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_height, const int input_width, const int channels,
    const T* output_data, const T* output_grad, const int output_height,
    const int output_width, T* input_grad) {
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    PADDLE_ASSERT(out_index < out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    // One thread per input element, launched on the context's stream.
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_height, output_width);
  }
};

/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int threads = 1024;
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data, output_height,
                              output_width, input_grad_data);
  }
};

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
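
Both kernels walk the flattened NCHW input with a grid-stride loop and rebase each element's recorded offset into the correct output feature map. A small host-side sketch of that index arithmetic (the shapes and the flat index are assumed values for illustration; the variable names mirror the kernel):

#include <cstdio>

int main() {
  // Assumed shapes: C = 3 channels, 2x2 input maps, 4x4 output maps.
  const int channels = 3;
  const int in_c_stride = 2 * 2;                   // input_height * input_width
  const int in_n_stride = in_c_stride * channels;
  const int out_c_stride = 4 * 4;                  // output_height * output_width
  const int out_n_stride = out_c_stride * channels;

  int i = 17;                        // a flat index into the NCHW input
  int bidx = i / in_n_stride;        // which batch item
  int boffset = i % in_n_stride;     // offset inside that item
  int cidx = boffset / in_c_stride;  // which channel
  // Base of the matching (batch, channel) map in the output; the kernel
  // then adds indices_data[i] to it before reading or writing.
  int out_offset = bidx * out_n_stride + cidx * out_c_stride;
  printf("i=%d -> batch=%d, channel=%d, output base=%d\n",
         i, bidx, cidx, out_offset);  // i=17 -> batch=1, channel=1, output base=64
  return 0;
}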

paddle/operators/math/unpooling.h

Lines changed: 40 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
namespace math {

// Forward max-unpooling: scatters `input` into `output` at the positions
// recorded in `indices` (the argmax offsets saved by max pooling).
template <typename Place, typename T>
class Unpool2dMaxFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output);
};

// Backward max-unpooling: gathers `output_grad` back into `input_grad`
// through the same recorded positions.
template <typename Place, class T>
class Unpool2dMaxGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
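
For completeness, the gradient path declared by Unpool2dMaxGradFunctor is the mirror image of the forward scatter: a gather from the recorded positions. A minimal standalone sketch (again not from the commit, reusing the same hypothetical 2x2/4x4 shapes and offsets as the earlier example):

#include <cstdio>

// Illustrative gather, mirroring the inner loop of the CPU
// Unpool2dMaxGradFunctor: each input-gradient element is read from the
// output-gradient position that the forward pass scattered to.
int main() {
  const int input_feasize = 2 * 2;
  const int output_feasize = 4 * 4;
  int indices[input_feasize] = {0, 7, 9, 14};  // assumed argmax offsets
  float output_grad[output_feasize];
  for (int j = 0; j < output_feasize; ++j) output_grad[j] = 0.1f * j;
  float input_grad[input_feasize];

  for (int i = 0; i < input_feasize; ++i) {
    input_grad[i] = output_grad[indices[i]];  // input_grad_data[i] = output_grad_data[index]
  }

  for (int i = 0; i < input_feasize; ++i) printf("%.1f ", input_grad[i]);
  printf("\n");  // prints: 0.0 0.7 0.9 1.4
  return 0;
}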
