
Commit e133df6

enable native fc forward

1 parent 038cbf7 commit e133df6

2 files changed: +33 -23 lines

paddle/fluid/operators/fc_mkldnn_op.cc
Lines changed: 1 addition & 0 deletions

@@ -128,6 +128,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
                    "Input must be with 2 or 4 dimensions, i.e. NCHW");
+    // TODO(intel): the src weight is IO and the MKL-DNN weight needs to be transposed!
     PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
                    "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
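Note: the TODO points at a layout mismatch. The native fc weight is stored IO (shape I x O), while MKL-DNN's inner-product primitive expects OI (shape O x I). A minimal sketch of the repacking that would be needed (helper name hypothetical, not part of this commit):

#include <cstdint>
#include <vector>

// Hypothetical helper: repack a row-major IO weight buffer (I x O) into the
// OI layout (O x I) that MKL-DNN's inner product expects.
template <typename T>
std::vector<T> TransposeIOToOI(const T* w_io, int64_t I, int64_t O) {
  std::vector<T> w_oi(static_cast<size_t>(I * O));
  for (int64_t i = 0; i < I; ++i) {
    for (int64_t o = 0; o < O; ++o) {
      w_oi[o * I + i] = w_io[i * O + o];  // element (i, o) moves to (o, i)
    }
  }
  return w_oi;
}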

paddle/fluid/operators/fc_op.cc
Lines changed: 32 additions & 23 deletions

@@ -15,6 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/fc_op.h"
 #include <vector>
 
+DECLARE_int32(paddle_num_threads);
+
 namespace paddle {
 namespace operators {
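DECLARE_int32 is the gflags counterpart of a DEFINE_int32 made in another translation unit; it makes FLAGS_paddle_num_threads (defined elsewhere in Paddle) visible here so the OpenMP pragma in the kernel below can consult it. A minimal sketch of the general pattern (file names hypothetical):

// flags.cc: define the flag exactly once, in one translation unit.
#include <gflags/gflags.h>
DEFINE_int32(paddle_num_threads, 1, "Number of worker threads.");

// consumer.cc: every other file only declares it.
#include <gflags/gflags.h>
DECLARE_int32(paddle_num_threads);

bool RunParallel() { return FLAGS_paddle_num_threads > 1; }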

@@ -25,25 +27,23 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Out(Output) of Fully Connected should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("W"),
                  "W(Input) of Fully Connected should not be null.");
-
+  // NCHW
   auto in_dims = ctx->GetInputDim("Input");
+  // IO, I = C * H * W
   auto w_dims = ctx->GetInputDim("W");
   std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
 
   if (ctx->HasInput("Bias")) {
     auto bias_dims = ctx->GetInputDim("Bias");
     PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
-    PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0],
+    PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
                       "The shape of Bias must be [1, dim].");
   }
   PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
                  "Fully Connected input should be 2-D or 4-D tensor.");
-
-  PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
-                 "Fully Connected input should be 2-D or 4-D tensor.");
-
-  PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0],
-                    framework::product(in_dims) / in_dims[0],
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
+                    "Fully Connected weights should be a 2-D tensor.");
+  PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
                     "Fully Connected input and weight size do not match.");
 
   ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
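The shape contract these checks encode: a 4-D NCHW input is treated as a matrix of shape (N, C*H*W), the IO weight must satisfy w_dims[0] == C*H*W, and the output is (N, w_dims[1]). A quick worked example under assumed shapes:

#include <cassert>
#include <cstdint>

int main() {
  // Assumed example shapes: NCHW input (8, 3, 4, 4), IO weight (48, 10).
  int64_t in_dims[] = {8, 3, 4, 4};
  int64_t w_dims[] = {48, 10};
  // Flattened feature count per sample: C * H * W = 3 * 4 * 4 = 48.
  int64_t flat = in_dims[1] * in_dims[2] * in_dims[3];
  assert(flat == w_dims[0]);  // mirrors the PADDLE_ENFORCE_EQ above
  // Output shape is {in_dims[0], w_dims[1]} = {8, 10}.
  return 0;
}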
@@ -54,7 +54,7 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn");) {
+  if (ctx.Attr<bool>("use_mkldnn")) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -75,8 +75,9 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   }
 
   if (ctx->HasInput("Bias")) {
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
+                   "Should have bias grad");
     auto bias_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias"));
     ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
   }
 }
@@ -85,7 +86,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn");) {
+  if (ctx.Attr<bool>("use_mkldnn")) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -95,9 +96,11 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
 }
 
 void FCOpMaker::Make() {
-  AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
-  AddInput("W", "(Tensor), The second input tensor of fc op.");
-  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D")
+  AddInput("Input",
+           "(Tensor), The input tensor of fully connected operator with "
+           "format (NCHW).");
+  AddInput("W", "(Tensor), The weight of fc op with shape (I, O).");
+  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O).")
       .AsDispensable();
   AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
   AddAttr<bool>("use_mkldnn",
@@ -120,25 +123,32 @@ template <typename T>
 class FCOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
-    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
     auto input = ctx.Input<Tensor>("Input");
     auto w = ctx.Input<Tensor>("W");
     auto bias = ctx.Input<Tensor>("Bias");
+    auto output = ctx.Output<Tensor>("Out");
+    auto in_dims = input->dims();
+    auto w_dims = w->dims();
 
+    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
+    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
-    auto output = ctx.Output<Tensor>("Out");
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
 
-    auto in_dims = ctx->GetInputDim("Input");
-    auto w_dims = ctx->GetInputDim("W");
-    std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
+    blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
+              static_cast<T>(1), input_data, w_data, static_cast<T>(0),
+              output_data);
 
     if (bias) {
       const T* bias_data = bias->data<T>();
+#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
+      for (int bs = 0; bs < in_dims[0]; bs++) {
+        blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
+                  output_data + bs * w_dims[1]);
+      }
     }
   }
 };
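The native forward pass is thus a single GEMM (Out = Input x W, with M = batch size in_dims[0], K = input width w_dims[0], N = output width w_dims[1]) followed by one AXPY per sample to broadcast the bias across rows. As a sanity reference, a minimal plain-loop sketch of the same computation (not part of the commit):

#include <cstddef>

// Naive reference of what the GEMM + AXPY pair above computes:
// out(M x O) = in(M x I) * w(I x O), then bias(O) is added to each row.
template <typename T>
void FCForwardRef(const T* in, const T* w, const T* bias, T* out,
                  int M, int I, int O) {
  for (int m = 0; m < M; ++m) {
    for (int o = 0; o < O; ++o) {
      T acc = bias ? bias[o] : static_cast<T>(0);
      for (int i = 0; i < I; ++i) {
        acc += in[m * I + i] * w[i * O + o];  // row-major, IO weight layout
      }
      out[m * O + o] = acc;
    }
  }
}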
@@ -150,5 +160,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
-REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel<float>,
-                       ops::FCMKLDNNOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);
