@@ -15,6 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/fc_op.h"
 #include <vector>
 
+DECLARE_int32(paddle_num_threads);
+
 namespace paddle {
 namespace operators {
 
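Background on the new DECLARE_int32 line: with gflags, a flag is defined exactly once via DEFINE_int32 in some translation unit, and DECLARE_int32 only re-declares it so other files can read FLAGS_paddle_num_threads (here it gates the OpenMP pragma added further down in the kernel). Below is a minimal standalone sketch of that pairing, assuming gflags is available; the file name and default value are illustrative and not taken from this patch, where the real definition lives elsewhere in the codebase.

    // flags_demo.cc -- illustrative sketch of the DEFINE/DECLARE pairing.
    #include <cstdio>
    #include "gflags/gflags.h"

    // Exactly one translation unit owns the definition (default value made up here).
    DEFINE_int32(paddle_num_threads, 1, "Number of worker threads.");

    // Any other file, such as fc_op.cc above, would instead write:
    //   DECLARE_int32(paddle_num_threads);
    // and then read FLAGS_paddle_num_threads directly.

    int main(int argc, char** argv) {
      gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
      std::printf("paddle_num_threads = %d\n", FLAGS_paddle_num_threads);
      return 0;
    }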
@@ -25,25 +27,23 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                   "Out(Output) of Fully Connected should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("W"),
                  "W(Input) of Fully Connected should not be null.");
-
+  // NCHW
   auto in_dims = ctx->GetInputDim("Input");
+  // IO, I=C*H*W
   auto w_dims = ctx->GetInputDim("W");
   std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
 
   if (ctx->HasInput("Bias")) {
     auto bias_dims = ctx->GetInputDim("Bias");
     PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
-    PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0],
+    PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
                       "The shape of Bias must be [1, dim].");
   }
   PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
                  "Fully Connected input should be 2-D or 4-D tensor.");
-
-  PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
-                 "Fully Connected input should be 2-D or 4-D tensor.");
-
-  PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0],
-                    framework::product(in_dims) / in_dims[0],
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
+                    "Fully Connected input should be 2-D tensor.");
+  PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
                     "Fully Connected input and weigth size do not match.");
 
   ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
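For readers of the checks above: Input arrives as an NCHW tensor (or an already flattened 2-D tensor), W is an (I, O) matrix with I = C*H*W, Bias is (1, O), and the produced Out is (N, O). A minimal sketch of the same dimension arithmetic outside the framework, with made-up sizes, assuming nothing beyond the standard library:

    // shape_check_sketch.cc -- mirrors the InferShape checks with hypothetical dims.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Hypothetical dims: Input NCHW = (8, 3, 4, 4), W = (48, 10), Bias = (1, 10).
      std::vector<int64_t> in_dims{8, 3, 4, 4};
      std::vector<int64_t> w_dims{48, 10};
      std::vector<int64_t> bias_dims{1, 10};

      int64_t in_numel = 1;
      for (auto d : in_dims) in_numel *= d;

      // Same constraints as the revised checks: 2-D or 4-D input, 2-D weight,
      // flattened input width must equal w_dims[0], bias width must equal w_dims[1].
      assert(in_dims.size() == 2 || in_dims.size() == 4);
      assert(w_dims.size() == 2);
      assert(in_numel / in_dims[0] == w_dims[0]);              // C*H*W == I
      assert(bias_dims[0] == 1 && bias_dims[1] == w_dims[1]);  // (1, O)

      // Output shape, as in output_shape above: (N, O).
      std::vector<int64_t> out_dims{in_dims[0], w_dims[1]};
      return 0;
    }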
@@ -54,7 +54,7 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn");) {
+  if (ctx.Attr<bool>("use_mkldnn")) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -75,8 +75,9 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   }
 
   if (ctx->HasInput("Bias")) {
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
+                   "Should have bias grad");
     auto bias_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias"));
     ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
   }
 }
@@ -85,7 +86,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn");) {
+  if (ctx.Attr<bool>("use_mkldnn")) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -95,9 +96,11 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
 }
 
 void FCOpMaker::Make() {
-  AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
-  AddInput("W", "(Tensor), The second input tensor of fc op.");
-  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D")
+  AddInput("Input",
+           "(Tensor), The input tensor of fully connected operator with format "
+           "(NCHW). ");
+  AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
+  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
       .AsDispensable();
   AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
   AddAttr<bool>("use_mkldnn",
@@ -120,25 +123,32 @@ template <typename T>
 class FCOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
-    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
     auto input = ctx.Input<Tensor>("Input");
     auto w = ctx.Input<Tensor>("W");
     auto bias = ctx.Input<Tensor>("Bias");
+    auto output = ctx.Output<Tensor>("Out");
+    auto in_dims = input->dims();
+    auto w_dims = w->dims();
 
+    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
+    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
-    auto output = ctx.Output<Tensor>("Out");
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
 
-    auto in_dims = ctx->GetInputDim("Input");
-    auto w_dims = ctx->GetInputDim("W");
-    std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
+    blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
+              static_cast<T>(1), input_data, w_data, static_cast<T>(0),
+              output_data);
 
     if (bias) {
       const T* bias_data = bias->data<T>();
+#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
+      for (int bs = 0; bs < in_dims[0]; bs++) {
+        blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
+                  output_data + bs * w_dims[1]);
+      }
     }
   }
 };
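Functionally, the GEMM plus per-row AXPY in the kernel above computes Out = Input * W and then broadcasts the single bias row over the batch. The following is a naive reference sketch of that arithmetic in plain C++ (not the Paddle BLAS wrappers), useful for checking the row-major index math; the function name and signature are hypothetical.

    #include <vector>

    // Reference semantics: out(N, O) = in(N, I) * w(I, O) + bias(1, O), row-major.
    void FcReference(const std::vector<float>& in, const std::vector<float>& w,
                     const std::vector<float>& bias, std::vector<float>* out,
                     int N, int I, int O) {
      out->assign(static_cast<size_t>(N) * O, 0.0f);
      for (int n = 0; n < N; ++n) {        // batch rows (GEMM M dimension)
        for (int o = 0; o < O; ++o) {      // output features (GEMM N dimension)
          float acc = 0.0f;
          for (int i = 0; i < I; ++i) {    // reduction (GEMM K dimension)
            acc += in[n * I + i] * w[i * O + o];
          }
          // The per-row AXPY in the kernel adds the same bias row to every batch row.
          (*out)[n * O + o] = acc + (bias.empty() ? 0.0f : bias[o]);
        }
      }
    }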
@@ -150,5 +160,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
-REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel<float>,
-                       ops::FCMKLDNNOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);