@@ -14,6 +14,9 @@ limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/fc_op.h"
16
16
#include < vector>
17
+ #include " paddle/fluid/operators/math/blas.h"
18
+
19
+ DECLARE_int32 (paddle_num_threads);
17
20
18
21
namespace paddle {
19
22
namespace operators {
@@ -25,26 +28,37 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
25
28
" Out(Output) of Fully Connected should not be null." );
26
29
PADDLE_ENFORCE (ctx->HasInput (" W" ),
27
30
" W(Input) of Fully Connected should not be null." );
28
-
31
+ // NCHW
29
32
auto in_dims = ctx->GetInputDim (" Input" );
33
+ // IO, I=C*H*W
30
34
auto w_dims = ctx->GetInputDim (" W" );
31
35
std::vector<int64_t > output_shape ({in_dims[0 ], w_dims[1 ]});
32
36
37
+ if (ctx->HasInput (" Bias" )) {
38
+ auto bias_dims = ctx->GetInputDim (" Bias" );
39
+ PADDLE_ENFORCE_EQ (bias_dims[0 ], 1 , " The shape of Bias must be [1, dim]." );
40
+ PADDLE_ENFORCE_EQ (bias_dims[1 ], w_dims[1 ],
41
+ " The shape of Bias must be [1, dim]." );
42
+ }
33
43
PADDLE_ENFORCE (in_dims.size () == 2 || in_dims.size () == 4 ,
34
44
" Fully Connected input should be 2-D or 4-D tensor." );
35
-
36
- PADDLE_ENFORCE (w_dims.size () == 2 || w_dims.size () == 4 ,
37
- " Fully Connected input should be 2-D or 4-D tensor." );
45
+ PADDLE_ENFORCE_EQ (w_dims.size (), 2UL ,
46
+ " Fully Connected input should be 2-D tensor." );
47
+ PADDLE_ENFORCE_EQ (framework::product (in_dims) / in_dims[0 ], w_dims[0 ],
48
+ " Fully Connected input and weigth size do not match." );
38
49
39
50
ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape));
40
51
ctx->ShareLoD (" Input" , " Out" );
41
52
}
42
53
43
54
framework::OpKernelType FCOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Pick the plain CPU kernel with no layout requirement by default, and
  // switch to the MKLDNN kernel/layout only when the attribute requests it.
  const bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
  framework::LibraryType library = use_mkldnn
                                       ? framework::LibraryType::kMKLDNN
                                       : framework::LibraryType::kPlain;
  framework::DataLayout layout = use_mkldnn
                                     ? framework::DataLayout::kMKLDNN
                                     : framework::DataLayout::kAnyLayout;
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
      layout, library);
}
@@ -60,27 +74,39 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
60
74
if (ctx->HasOutput (framework::GradVarName (" W" ))) {
61
75
ctx->SetOutputDim (framework::GradVarName (" W" ), w_dims);
62
76
}
77
+
78
+ if (ctx->HasInput (" Bias" )) {
79
+ PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" Bias" )),
80
+ " Should have bias grad" );
81
+ auto bias_dims = ctx->GetInputDim (" Bias" );
82
+ ctx->SetOutputDim (framework::GradVarName (" Bias" ), bias_dims);
83
+ }
63
84
}
64
85
65
86
framework::OpKernelType FCOpGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Kernel selection mirrors the forward op: plain/any-layout unless the
  // "use_mkldnn" attribute asks for the MKLDNN kernel and layout.
  framework::LibraryType library;
  framework::DataLayout layout;
  if (ctx.Attr<bool>("use_mkldnn")) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  } else {
    library = framework::LibraryType::kPlain;
    layout = framework::DataLayout::kAnyLayout;
  }
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
      layout, library);
}
74
98
75
99
void FCOpMaker::Make () {
76
- AddInput (" Input" , " (Tensor) The input tensor of fully connected operator. " );
77
- AddInput (" W" , " (Tensor), The second input tensor of fc op." );
100
+ AddInput (" Input" ,
101
+ " (Tensor), The input tensor of fully connected operator with format "
102
+ " (NCHW). " );
103
+ AddInput (" W" , " (Tensor), The weight fc op with shape (I, O)." );
104
+ AddInput (" Bias" , " (Tensor, optional) Bias vector with shape (1 x O" )
105
+ .AsDispensable ();
78
106
AddOutput (" Out" , " (Tensor) The output tensor of fully connected operator. " );
79
107
AddAttr<bool >(" use_mkldnn" ,
80
108
" (bool, default false) Only used in mkldnn kernel" )
81
109
.SetDefault (false );
82
- AddAttr<bool >(" bias_attr" , " (bool, default false) Only used in mkldnn kernel" )
83
- .SetDefault (false );
84
110
AddComment (R"DOC(
85
111
Fully Connected Operator.
86
112
@@ -94,9 +120,47 @@ void FCOpMaker::Make() {
94
120
)DOC" );
95
121
}
96
122
123
+ template <typename T>
124
+ class FCOpKernel : public framework ::OpKernel<T> {
125
+ public:
126
+ void Compute (const paddle::framework::ExecutionContext& ctx) const override {
127
+ PADDLE_ENFORCE (platform::is_cpu_place (ctx.GetPlace ()),
128
+ " It must use CPUPlace." );
129
+ auto input = ctx.Input <Tensor>(" Input" );
130
+ auto w = ctx.Input <Tensor>(" W" );
131
+ auto bias = ctx.Input <Tensor>(" Bias" );
132
+ auto output = ctx.Output <Tensor>(" Out" );
133
+ auto in_dims = input->dims ();
134
+ auto w_dims = w->dims ();
135
+
136
+ auto & dev_ctx = ctx.template device_context <platform::CPUDeviceContext>();
137
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
138
+ const T* input_data = input->data <T>();
139
+ const T* w_data = w->data <T>();
140
+ T* output_data = output->mutable_data <T>(ctx.GetPlace ());
141
+
142
+ blas.GEMM (CblasNoTrans, CblasNoTrans, in_dims[0 ], w_dims[1 ], w_dims[0 ],
143
+ static_cast <T>(1 ), input_data, w_data, static_cast <T>(0 ),
144
+ output_data);
145
+
146
+ if (bias) {
147
+ const T* bias_data = bias->data <T>();
148
+ #ifdef PADDLE_WITH_MKLML
149
+ #pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
150
+ #endif
151
+ for (int bs = 0 ; bs < in_dims[0 ]; bs++) {
152
+ blas.AXPY (w_dims[1 ], static_cast <T>(1 ), bias_data,
153
+ output_data + bs * w_dims[1 ]);
154
+ }
155
+ }
156
+ }
157
+ };
158
+
97
159
} // namespace operators
98
160
} // namespace paddle
99
161
100
namespace ops = paddle::operators;
// Register the forward op (with the default grad-op maker), its gradient op,
// and the plain CPU forward kernels for float and double.
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);
0 commit comments