Skip to content

Commit a3ccbdb

Browse files
authored
Cudnn conv op (#4195)
* add cudnn_conv_op * WIP * update * update * fix grad check * use platform::memory * add support group for cudnn * update * follow comments * fix onlycpu build * update cuda define * follow comments * follow comments * merge with updates * fix compile error * follow comments * follow comments
1 parent b504a23 commit a3ccbdb

File tree

9 files changed

+489
-108
lines changed

9 files changed

+489
-108
lines changed

paddle/framework/operator.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,15 @@ class ExecutionContext {
289289
return device_context_;
290290
}
291291

292+
#ifdef PADDLE_WITH_CUDA
293+
const platform::CUDADeviceContext& cuda_device_context() const {
294+
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
295+
auto cuda_ctx =
296+
reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
297+
return *cuda_ctx;
298+
}
299+
#endif
300+
292301
private:
293302
const OperatorBase& op_;
294303
const Scope& scope_;

paddle/operators/conv2d_op.cc

Lines changed: 73 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -12,111 +12,91 @@
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/operators/gemm_conv2d_op.h"
15+
#include "paddle/operators/conv2d_op.h"
1616

1717
namespace paddle {
1818
namespace operators {
1919

20-
int outputSize(int input_size, int filter_size, int padding, int stride) {
21-
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
22-
return output_size;
20+
void Conv2DOp::InferShape(framework::InferShapeContext* ctx) const {
21+
PADDLE_ENFORCE(ctx->HasInput("Input"),
22+
"Input(Input) of Conv2DOp should not be null.");
23+
PADDLE_ENFORCE(ctx->HasInput("Filter"),
24+
"Input(Filter) of Conv2DOp should not be null.");
25+
PADDLE_ENFORCE(ctx->HasOutput("Output"),
26+
"Output(Output) of Conv2DOp should not be null.");
27+
28+
auto in_dims = ctx->GetInputDim("Input");
29+
auto filter_dims = ctx->GetInputDim("Filter");
30+
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
31+
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
32+
int groups = ctx->Attrs().Get<int>("groups");
33+
int input_channels = in_dims[1];
34+
int output_channels = filter_dims[0];
35+
36+
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
37+
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
38+
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
39+
"The number of input channels should be equal to filter "
40+
"channels * groups.");
41+
PADDLE_ENFORCE_EQ(
42+
output_channels % groups, 0,
43+
"The number of output channels should be divided by groups.");
44+
45+
auto output_height =
46+
OutputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
47+
auto output_width =
48+
OutputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
49+
ctx->SetOutputDim("Output",
50+
{in_dims[0], filter_dims[0], output_height, output_width});
2351
}
2452

25-
class Conv2DOp : public framework::OperatorWithKernel {
26-
public:
27-
using framework::OperatorWithKernel::OperatorWithKernel;
28-
29-
protected:
30-
void InferShape(framework::InferShapeContext* ctx) const override {
31-
PADDLE_ENFORCE(ctx->HasInput("Input"),
32-
"Input(Input) of Conv2DOp should not be null.");
33-
PADDLE_ENFORCE(ctx->HasInput("Filter"),
34-
"Input(Filter) of Conv2DOp should not be null.");
35-
PADDLE_ENFORCE(ctx->HasOutput("Output"),
36-
"Output(Output) of Conv2DOp should not be null.");
37-
38-
auto in_dims = ctx->GetInputDim("Input");
39-
auto filter_dims = ctx->GetInputDim("Filter");
40-
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
41-
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
42-
int groups = ctx->Attrs().Get<int>("groups");
43-
int input_channels = in_dims[1];
44-
int output_channels = filter_dims[0];
45-
46-
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
47-
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
48-
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
49-
"The number of input channels should be equal to filter "
50-
"channels * groups.");
51-
PADDLE_ENFORCE_EQ(
52-
output_channels % groups, 0,
53-
"The number of output channels should be divided by groups.");
54-
55-
auto output_height =
56-
outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
57-
auto output_width =
58-
outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
59-
ctx->SetOutputDim(
60-
"Output", {in_dims[0], filter_dims[0], output_height, output_width});
61-
}
62-
};
63-
64-
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
65-
public:
66-
Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
67-
: OpProtoAndCheckerMaker(proto, op_checker) {
68-
AddInput(
69-
"Input",
70-
"The input tensor of convolution operator. "
71-
"The format of input tensor is NCHW. Where N is batch size, C is the "
72-
"number of channels, H and W is the height and width of image.");
73-
AddInput(
74-
"Filter",
75-
"The filter tensor of convolution operator."
76-
"The format of the filter tensor is MCHW, where M is the number of "
77-
"output image channels, C is the number of input image channels, "
78-
"H and W is height and width of filter. "
79-
"If the groups attribute is greater than 1, C equal the number of "
80-
"input image channels divided by the groups.");
81-
AddOutput("Output",
82-
"The output tensor of convolution operator."
83-
"The format of output tensor is also NCHW.");
84-
AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
85-
.SetDefault({1, 1});
86-
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
87-
.SetDefault({0, 0});
88-
AddAttr<int>(
89-
"groups",
90-
"group size of convolution operator. "
91-
"Refer to grouped convolution in Alex Krizhevsky's paper: "
92-
"when group=2, the first half of the filters are only connected to the "
93-
"first half of the input channels, and the second half only connected "
94-
"to the second half.")
95-
.SetDefault(1);
96-
AddComment(R"DOC(
53+
Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
54+
framework::OpAttrChecker* op_checker)
55+
: OpProtoAndCheckerMaker(proto, op_checker) {
56+
AddInput(
57+
"Input",
58+
"The input tensor of convolution operator. "
59+
"The format of input tensor is NCHW. Where N is batch size, C is the "
60+
"number of channels, H and W is the height and width of image.");
61+
AddInput("Filter",
62+
"The filter tensor of convolution operator."
63+
"The format of the filter tensor is MCHW, where M is the number of "
64+
"output image channels, C is the number of input image channels, "
65+
"H and W is height and width of filter. "
66+
"If the groups attribute is greater than 1, C equal the number of "
67+
"input image channels divided by the groups.");
68+
AddOutput("Output",
69+
"The output tensor of convolution operator."
70+
"The format of output tensor is also NCHW.");
71+
AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
72+
.SetDefault({1, 1});
73+
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
74+
.SetDefault({0, 0});
75+
AddAttr<int>(
76+
"groups",
77+
"group size of convolution operator. "
78+
"Refer to grouped convolution in Alex Krizhevsky's paper: "
79+
"when group=2, the first half of the filters are only connected to the "
80+
"first half of the input channels, and the second half only connected "
81+
"to the second half.")
82+
.SetDefault(1);
83+
AddComment(R"DOC(
9784
The convolution operation calculates the output based on the input, filter
9885
and strides, paddings, groups parameters. The size of each dimension of the
9986
parameters is checked in the infer-shape.
10087
)DOC");
101-
}
102-
};
103-
104-
class Conv2DOpGrad : public framework::OperatorWithKernel {
105-
public:
106-
using framework::OperatorWithKernel::OperatorWithKernel;
88+
}
10789

108-
protected:
109-
void InferShape(framework::InferShapeContext* ctx) const override {
110-
auto in_dims = ctx->GetInputDim("Input");
111-
auto filter_dims = ctx->GetInputDim("Filter");
112-
if (ctx->HasOutput(framework::GradVarName("Input"))) {
113-
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
114-
}
115-
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
116-
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
117-
}
90+
void Conv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
91+
auto in_dims = ctx->GetInputDim("Input");
92+
auto filter_dims = ctx->GetInputDim("Filter");
93+
if (ctx->HasOutput(framework::GradVarName("Input"))) {
94+
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
11895
}
119-
};
96+
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
97+
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
98+
}
99+
}
120100

121101
} // namespace operators
122102
} // namespace paddle

paddle/operators/conv2d_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/operators/gemm_conv2d_op.h"
15+
#include "paddle/operators/conv2d_op.h"
1616

1717
namespace ops = paddle::operators;
1818

paddle/operators/gemm_conv2d_op.h renamed to paddle/operators/conv2d_op.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,38 @@ namespace operators {
2424

2525
using Tensor = framework::Tensor;
2626

27+
// Base convolution operator definations for other conv
28+
// like operators to reuse the implementation.
29+
inline int OutputSize(int input_size, int filter_size, int padding,
30+
int stride) {
31+
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
32+
return output_size;
33+
}
34+
35+
// Define Op classes in .h file so that other conv
36+
// operator implementations can reuse the code.
37+
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
38+
public:
39+
Conv2DOpMaker(framework::OpProto* proto,
40+
framework::OpAttrChecker* op_checker);
41+
};
42+
43+
class Conv2DOp : public framework::OperatorWithKernel {
44+
public:
45+
using framework::OperatorWithKernel::OperatorWithKernel;
46+
47+
protected:
48+
void InferShape(framework::InferShapeContext* ctx) const override;
49+
};
50+
51+
class Conv2DOpGrad : public framework::OperatorWithKernel {
52+
public:
53+
using framework::OperatorWithKernel::OperatorWithKernel;
54+
55+
protected:
56+
void InferShape(framework::InferShapeContext* ctx) const override;
57+
};
58+
2759
template <typename Place, typename T>
2860
class GemmConv2DKernel : public framework::OpKernel<T> {
2961
public:
@@ -74,7 +106,6 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
74106

75107
framework::DDim output_matrix_shape = {output_channels,
76108
output_height * output_width};
77-
78109
// convolution operator: im2col + gemm
79110
int in_step = input_channels / groups;
80111
int out_step = output_channels / groups;

paddle/operators/conv_cudnn_op.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/conv2d_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class CudnnConvOpMaker : public Conv2DOpMaker {
21+
public:
22+
CudnnConvOpMaker(framework::OpProto* proto,
23+
framework::OpAttrChecker* op_checker)
24+
: Conv2DOpMaker(proto, op_checker) {
25+
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
26+
.SetDefault(std::vector<int>{1, 1});
27+
AddAttr<int>("workspace_size_MB",
28+
"workspace size for cudnn, in MB, "
29+
"workspace is a section of GPU memory which will be "
30+
"allocated/freed each time the operator runs, larger "
31+
"workspace size can increase performance but also requires "
32+
"better hardward. This size should be carefully setted.")
33+
.SetDefault(4096);
34+
}
35+
};
36+
37+
} // namespace operators
38+
} // namespace paddle
39+
40+
namespace ops = paddle::operators;
41+
REGISTER_OP(conv_cudnn, ops::Conv2DOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
42+
ops::Conv2DOpGrad);
43+
REGISTER_OP_CPU_KERNEL(
44+
conv_cudnn, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
45+
REGISTER_OP_CPU_KERNEL(
46+
conv_cudnn_grad,
47+
ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);

0 commit comments

Comments
 (0)