@@ -30,21 +30,34 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
30
30
auto w_dims = ctx->GetInputDim (" W" );
31
31
std::vector<int64_t > output_shape ({in_dims[0 ], w_dims[1 ]});
32
32
33
+ if (ctx->HasInput (" Bias" )) {
34
+ auto bias_dims = ctx->GetInputDim (" Bias" );
35
+ PADDLE_ENFORCE_EQ (bias_dims[0 ], 1 , " The shape of Bias must be [1, dim]." );
36
+ PADDLE_ENFORCE_EQ (bias_dims[1 ], framework::product (w_dims) / w_dims[0 ],
37
+ " The shape of Bias must be [1, dim]." );
38
+ }
33
39
PADDLE_ENFORCE (in_dims.size () == 2 || in_dims.size () == 4 ,
34
40
" Fully Connected input should be 2-D or 4-D tensor." );
35
41
36
42
PADDLE_ENFORCE (w_dims.size () == 2 || w_dims.size () == 4 ,
37
43
" Fully Connected input should be 2-D or 4-D tensor." );
38
44
45
+ PADDLE_ENFORCE_EQ (framework::product (w_dims) / w_dims[0 ],
46
+ framework::product (in_dims) / in_dims[0 ],
47
+ " Fully Connected input and weigth size do not match." );
48
+
39
49
ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape));
40
50
ctx->ShareLoD (" Input" , " Out" );
41
51
}
42
52
43
53
framework::OpKernelType FCOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Default to the plain CPU kernel with no layout preference; switch to the
  // MKLDNN kernel/layout only when the op instance was created with
  // use_mkldnn=true.
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  // BUG FIX: the original read `if (ctx.Attr<bool>("use_mkldnn");)` — the
  // stray ';' inside the condition is a syntax error.
  if (ctx.Attr<bool>("use_mkldnn")) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
  // The kernel key's data type is taken from the "Input" tensor.
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
      ctx.GetPlace(), layout, library);
}
@@ -60,13 +73,22 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
60
73
if (ctx->HasOutput (framework::GradVarName (" W" ))) {
61
74
ctx->SetOutputDim (framework::GradVarName (" W" ), w_dims);
62
75
}
76
+
77
+ if (ctx->HasInput (" Bias" )) {
78
+ auto bias_dims = ctx->GetInputDim (" Bias" );
79
+ PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" Bias" ));
80
+ ctx->SetOutputDim (framework::GradVarName (" Bias" ), bias_dims);
81
+ }
63
82
}
64
83
65
84
framework::OpKernelType FCOpGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Mirror of FCOp::GetExpectedKernelType: plain/any-layout by default,
  // MKLDNN library and layout when use_mkldnn=true.
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  // BUG FIX: removed the stray ';' that the original left inside the
  // if-condition (`if (ctx.Attr<bool>("use_mkldnn");)`), which does not
  // compile.
  if (ctx.Attr<bool>("use_mkldnn")) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
  // The kernel key's data type is taken from the forward "Input" tensor.
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
      ctx.GetPlace(), layout, library);
}
@@ -75,12 +97,12 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
75
97
void FCOpMaker::Make () {
76
98
AddInput (" Input" , " (Tensor) The input tensor of fully connected operator. " );
77
99
AddInput (" W" , " (Tensor), The second input tensor of fc op." );
100
+ AddInput (" Bias" , " (Tensor, optional) Bias vector with shape (1 x D" )
101
+ .AsDispensable ();
78
102
AddOutput (" Out" , " (Tensor) The output tensor of fully connected operator. " );
79
103
AddAttr<bool >(" use_mkldnn" ,
80
104
" (bool, default false) Only used in mkldnn kernel" )
81
105
.SetDefault (false );
82
- AddAttr<bool >(" bias_attr" , " (bool, default false) Only used in mkldnn kernel" )
83
- .SetDefault (false );
84
106
AddComment (R"DOC(
85
107
Fully Connected Operator.
86
108
@@ -94,9 +116,39 @@ void FCOpMaker::Make() {
94
116
)DOC" );
95
117
}
96
118
119
+ template <typename T>
120
+ class FCOpKernel : public framework ::OpKernel<T> {
121
+ public:
122
+ void Compute (const paddle::framework::ExecutionContext& ctx) const override {
123
+ PADDLE_ENFORCE (paddle::platform::is_cpu_place (ctx.GetPlace ()),
124
+ " It must use CPUPlace." );
125
+ auto & dev_ctx = ctx.template device_context <CPUDeviceContext>();
126
+ auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
127
+ auto input = ctx.Input <Tensor>(" Input" );
128
+ auto w = ctx.Input <Tensor>(" W" );
129
+ auto b = ctx.Input <Tensor>(" Bias" );
130
+
131
+ const T* input_data = input->data <T>();
132
+ const T* w_data = w->data <T>();
133
+ auto output = ctx.Output <Tensor>(" Out" );
134
+ T* output_data = output->mutable_data <T>(ctx.GetPlace ());
135
+
136
+ auto in_dims = ctx->GetInputDim (" Input" );
137
+ auto w_dims = ctx->GetInputDim (" W" );
138
+ std::vector<int64_t > output_shape ({in_dims[0 ], w_dims[1 ]});
139
+
140
+ if (bias) {
141
+ const T* bias_data = bias->data <T>();
142
+ }
143
+ }
144
+ };
145
+
97
146
} // namespace operators
98
147
} // namespace paddle
99
148
100
- REGISTER_OPERATOR (fc, paddle::operators::FCOp, paddle::operators::FCOpMaker,
149
+ namespace ops = paddle::operators;
150
+ REGISTER_OPERATOR (fc, ops::FCOp, ops::FCOpMaker,
101
151
paddle::framework::DefaultGradOpDescMaker<true >);
102
- REGISTER_OPERATOR (fc_grad, paddle::operators::FCOpGrad);
152
+ REGISTER_OPERATOR (fc_grad, ops::FCOpGrad);
153
+ REGISTER_OP_CPU_KERNEL (fc, ops::FCMKLDNNOpKernel<float >,
154
+ ops::FCMKLDNNOpKernel<double >);
0 commit comments