
Commit 88bd7e1

Merge pull request #15027 from shippingwang/shufflechannel
Add Shuffle Channel Operator
2 parents e043ea9 + 14f2a10

7 files changed: +468 -0

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions

@@ -213,6 +213,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act'
 paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
+paddle.fluid.layers.shuffle_channel ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0))
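The new spec entry exposes shuffle_channel as a Python layer taking an input variable x and a group count. A minimal usage sketch follows; the input shape and the program/executor boilerplate are illustrative assumptions, and only the shuffle_channel(x, group, name=None) signature comes from the spec line above.

import numpy as np
import paddle.fluid as fluid

# Build a tiny network with the new layer (shapes are assumed for illustration).
x = fluid.layers.data(name='x', shape=[4, 2, 2], dtype='float32')  # C=4, H=W=2
out = fluid.layers.shuffle_channel(x=x, group=2)  # shuffle 4 channels in 2 groups

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

feed_x = np.arange(16, dtype='float32').reshape(1, 4, 2, 2)  # batch of 1
res, = exe.run(feed={'x': feed_x}, fetch_list=[out])
print(res.shape)  # (1, 4, 2, 2): shape is preserved, only channel order changes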
paddle/fluid/operators/shuffle_channel_op.cc

Lines changed: 113 additions & 0 deletions

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shuffle_channel_op.h"

namespace paddle {
namespace operators {

class ShuffleChannelOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ShuffleChannelOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ShuffleChannelOp should not be null.");

    auto input_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE(input_dims.size() == 4,
                   "Input(X) should be a 4-D tensor with layout NCHW.");

    ctx->SetOutputDim("Out", input_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
                                   ctx.device_context());
  }
};

class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, default Tensor<float>), "
             "the input feature data of ShuffleChannelOp, the layout is NCHW.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>), the output of "
              "ShuffleChannelOp. The layout is NCHW.");
    AddAttr<int>("group", "the number of groups.")
        .SetDefault(1)
        .AddCustomChecker([](const int& group) {
          PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
        });

    AddComment(R"DOC(
Shuffle Channel Operator

This operator shuffles the channels of input x. It divides the input
channels in each group into several subgroups, and obtains a new order
by selecting elements from every subgroup one by one.

The shuffle channel operation makes it possible to build more powerful
structures with multiple grouped convolutional layers. Please refer to
the ShuffleNet paper for more details:
https://arxiv.org/pdf/1707.01083.pdf
)DOC");
  }
};

class ShuffleChannelGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@Grad) should not be null");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@Grad) should not be null");

    auto input_dims = ctx->GetInputDim("X");
    PADDLE_ENFORCE(input_dims.size() == 4,
                   "Input(X) should be a 4-D tensor with layout NCHW.");

    ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
                                   ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
                  ops::ShuffleChannelOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);

REGISTER_OP_CPU_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
                                    double>);
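The operator's DOC comment describes the shuffle in words; concretely, channel shuffle is equivalent to reshaping the C channels into a (group, C/group) grid, transposing it, and flattening back, as described in the ShuffleNet paper. A short NumPy restatement of that equivalence (my sketch, not code from this PR):

import numpy as np

def shuffle_channel_ref(x, group):
    # Reference channel shuffle: split channels into `group` rows of
    # C // group columns, transpose, and flatten. This realizes the
    # (i * group_column + j) -> (j * group_row + i) permutation used by
    # the operator, with group_row = group and group_column = C // group.
    n, c, h, w = x.shape
    assert c % group == 0, "channel count must be divisible by group"
    return (x.reshape(n, group, c // group, h, w)
             .transpose(0, 2, 1, 3, 4)
             .reshape(n, c, h, w))

# 6 channels in 2 groups: [a1 a2 a3 b1 b2 b3] -> [a1 b1 a2 b2 a3 b3]
x = np.arange(6, dtype=np.float32).reshape(1, 6, 1, 1)
print(shuffle_channel_ref(x, 2).ravel())  # [0. 3. 1. 4. 2. 5.]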
paddle/fluid/operators/shuffle_channel_op.cu

Lines changed: 125 additions & 0 deletions

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/shuffle_channel_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}

template <typename T>
__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
                               T* output, const T* input, int group_row,
                               int group_column, int len) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  // Grid-stride loop over all N*C*H*W elements. Each flat index ii
  // decomposes into (n, i, j, k): batch n, group i, subgroup j, spatial
  // offset k. The element moves from channel (i * group_column + j) to
  // channel (j * group_row + i).
  for (int ii = index; ii < nthreads; ii += offset) {
    const int n = ii / group_row / group_column / len;
    const int i = (ii / group_column / len) % group_row;
    const int j = ii / len % group_column;
    const int k = ii - (n * feature_map_size + (i * group_column + j) * len);
    T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
    p_o[k] = input[ii];
  }
}

template <typename DeviceContext, typename T>
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* output = ctx.Output<framework::Tensor>("Out");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
    auto channel = input_dims[1];
    auto height = input_dims[2];
    auto width = input_dims[3];

    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;
    int group_row = group;
    int group_column = channel / group_row;
    // count is the product N * C * H * W, the same as numel().
    int count = num * group_column * group_row * sp_sz;

    int blocks = NumBlocks(output->numel());
    int threads = kNumCUDAThreads;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    ShuffleChannel<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        count, feature_map_size, output_data, input_data, group_row,
        group_column, sp_sz);
  }
};

template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
    auto channel = input_dims[1];
    auto height = input_dims[2];
    auto width = input_dims[3];
    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;

    int group_row = group;
    int group_column = channel / group_row;
    auto* output_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* input_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    const T* output_grad_data = output_grad->data<T>();

    int blocks = NumBlocks(output_grad->numel());
    int threads = kNumCUDAThreads;
    int count = num * group_column * group_row * sp_sz;

    // The backward pass applies the inverse permutation, which is the
    // same shuffle with the roles of group_row and group_column swapped.
    ShuffleChannel<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        count, feature_map_size, input_grad_data, output_grad_data,
        group_column, group_row, sp_sz);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel,
    ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                    double>);
REGISTER_OP_CUDA_KERNEL(
    shuffle_channel_grad,
    ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        double>);
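The kernel's flat-index arithmetic can be checked on the host: decomposing each NCHW index into (n, i, j, k) and scattering to channel j * group_row + i must agree with the reshape/transpose formulation. A small verification sketch (written for this page, not part of the PR):

import numpy as np

def kernel_permutation(nb, c, h, w, group):
    # Mimic the CUDA kernel's index arithmetic for every "thread" index.
    sp_sz = h * w
    gr, gc = group, c // group          # group_row, group_column
    fms = c * sp_sz                     # feature_map_size
    out = np.empty(nb * c * sp_sz, dtype=np.int64)
    for idx in range(nb * c * sp_sz):
        n = idx // (gr * gc * sp_sz)
        i = (idx // (gc * sp_sz)) % gr
        j = (idx // sp_sz) % gc
        k = idx - (n * fms + (i * gc + j) * sp_sz)
        out[n * fms + (j * gr + i) * sp_sz + k] = idx
    return out

src = np.arange(2 * 6 * 2 * 2)
ref = src.reshape(2, 2, 3, 2, 2).transpose(0, 2, 1, 3, 4).reshape(-1)
assert (kernel_permutation(2, 6, 2, 2, 2) == ref).all()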
paddle/fluid/operators/shuffle_channel_op.h

Lines changed: 95 additions & 0 deletions

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <cstring>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ShuffleChannelOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* output = ctx.Output<framework::Tensor>("Out");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
    auto channel = input_dims[1];
    auto height = input_dims[2];
    auto width = input_dims[3];

    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;
    int group_row = group;
    int group_column = channel / group_row;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // Copy the H*W feature map of input channel (i * group_column + j)
    // to output channel (j * group_row + i), one channel at a time.
    for (int n = 0; n < num; ++n) {
      for (int i = 0; i < group_row; ++i) {
        for (int j = 0; j < group_column; ++j) {
          const T* p_i = input_data + n * feature_map_size +
                         (i * group_column + j) * sp_sz;
          T* p_o =
              output_data + n * feature_map_size + (j * group_row + i) * sp_sz;
          memcpy(p_o, p_i, sizeof(T) * sp_sz);
        }
      }
    }
  }
};
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    int group = ctx.Attr<int>("group");

    auto input_dims = input->dims();
    auto num = input_dims[0];
    auto channel = input_dims[1];
    auto height = input_dims[2];
    auto width = input_dims[3];
    auto feature_map_size = channel * height * width;
    auto sp_sz = height * width;

    int group_row = group;
    int group_column = channel / group_row;

    auto* output_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* input_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    const T* output_grad_data = output_grad->data<T>();
    // Backward applies the inverse permutation: the gradient of input
    // channel (i * group_column + j) is read from output channel
    // (j * group_row + i), where the forward pass placed that channel.
    for (int n = 0; n < num; ++n) {
      for (int i = 0; i < group_row; ++i) {
        for (int j = 0; j < group_column; ++j) {
          const T* p_i = output_grad_data + n * feature_map_size +
                         (j * group_row + i) * sp_sz;
          T* p_o = input_grad_data + n * feature_map_size +
                   (i * group_column + j) * sp_sz;
          memcpy(p_o, p_i, sizeof(T) * sp_sz);
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
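Because the backward kernels run the shuffle with group_row and group_column swapped, it is worth confirming that the swapped shuffle really is the inverse permutation. A tiny check over channel indices (my verification aid, not PR code):

# Forward sends channel i*gc + j to j*gr + i; the shuffle with the two
# factors swapped must send it back. Example: C = 6, group = 2.
gr, gc = 2, 3  # group_row, group_column
fwd = {i * gc + j: j * gr + i for i in range(gr) for j in range(gc)}
bwd = {i * gr + j: j * gc + i for i in range(gc) for j in range(gr)}
assert all(bwd[fwd[c]] == c for c in range(gr * gc))  # round trip is identity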
