@@ -27,11 +27,9 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
27
27
" Out(Output) of Fully Connected should not be null." );
28
28
PADDLE_ENFORCE (ctx->HasInput (" W" ),
29
29
" W(Input) of Fully Connected should not be null." );
30
- // NCHW
30
+
31
31
auto in_dims = ctx->GetInputDim (" Input" );
32
- // IO, I=C*H*W
33
32
auto w_dims = ctx->GetInputDim (" W" );
34
- std::vector<int64_t > output_shape ({in_dims[0 ], w_dims[1 ]});
35
33
36
34
if (ctx->HasInput (" Bias" )) {
37
35
auto bias_dims = ctx->GetInputDim (" Bias" );
@@ -44,14 +42,32 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
44
42
" The shape of Bias must be [1, dim]." );
45
43
}
46
44
}
47
- PADDLE_ENFORCE (in_dims.size () == 2 || in_dims.size () == 4 ,
48
- " Fully Connected input should be 2-D or 4-D tensor." );
45
+
46
+ if (ctx->Attrs ().Get <bool >(" use_mkldnn" )) {
47
+ PADDLE_ENFORCE (in_dims.size () == 2 || in_dims.size () == 4 ,
48
+ " Fully Connected input should be 2-D or 4-D tensor." );
49
+ }
49
50
PADDLE_ENFORCE_EQ (w_dims.size (), 2UL ,
50
51
" Fully Connected input should be 2-D tensor." );
51
- PADDLE_ENFORCE_EQ (framework::product (in_dims) / in_dims[0 ], w_dims[0 ],
52
- " Fully Connected input and weigth size do not match." );
52
+ int in_num_col_dims = ctx->Attrs ().Get <int >(" in_num_col_dims" );
53
+ PADDLE_ENFORCE_GT (
54
+ in_dims.size (), in_num_col_dims,
55
+ " The input tensor Input's rank of FCOp should be larger than "
56
+ " in_num_col_dims." );
57
+
58
+ auto in_mat_dims = framework::flatten_to_2d (in_dims, in_num_col_dims);
59
+ PADDLE_ENFORCE_EQ (
60
+ in_mat_dims[1 ], w_dims[0 ],
61
+ " Fully Connected input and weigth size do not match. %s, %s" );
62
+
63
+ std::vector<int64_t > output_dims;
64
+ output_dims.reserve (static_cast <size_t >(in_num_col_dims + 1 ));
65
+ for (int i = 0 ; i < in_num_col_dims; ++i) {
66
+ output_dims.push_back (in_dims[i]);
67
+ }
68
+ output_dims.push_back (w_dims[1 ]);
53
69
54
- ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape ));
70
+ ctx->SetOutputDim (" Out" , framework::make_ddim (output_dims ));
55
71
ctx->ShareLoD (" Input" , " Out" );
56
72
}
57
73
@@ -101,12 +117,15 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
101
117
}
102
118
103
119
void FCOpMaker::Make () {
104
- AddInput (" Input" ,
105
- " (Tensor), The input tensor of fully connected operator with format "
106
- " (NCHW). " );
120
+ AddInput (" Input" , " (Tensor), The input tensor of fully connected operator." );
107
121
AddInput (" W" , " (Tensor), The weight fc op with shape (I, O)." );
108
122
AddInput (" Bias" , " (Tensor, optional) Bias vector with shape (1 x O" )
109
123
.AsDispensable ();
124
+ AddAttr<int >(" x_num_col_dims" ,
125
+ " (int, default 1), The fc op can take tensors with more than "
126
+ " two dimensions as its inputs." )
127
+ .SetDefault (1 )
128
+ .EqualGreaterThan (1 );
110
129
AddOutput (" Out" , " (Tensor) The output tensor of fully connected operator. " );
111
130
AddAttr<bool >(" use_mkldnn" ,
112
131
" (bool, default false) Only used in mkldnn kernel" )
@@ -131,13 +150,15 @@ class FCOpKernel : public framework::OpKernel<T> {
131
150
auto output = ctx.Output <Tensor>(" Out" );
132
151
auto in_dims = input->dims ();
133
152
auto w_dims = w->dims ();
153
+ auto out_dims = output->dims ();
154
+ int M = framework::product (out_dims) / out_dims[out_dims.size () - 1 ];
134
155
135
156
const T* input_data = input->data <T>();
136
157
const T* w_data = w->data <T>();
137
158
T* output_data = output->mutable_data <T>(ctx.GetPlace ());
138
159
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
139
160
math::FCCompute<platform::CPUDeviceContext, T>(
140
- blas, in_dims[ 0 ] , w_dims[1 ], w_dims[0 ], input_data, w_data, output_data,
161
+ blas, M , w_dims[1 ], w_dims[0 ], input_data, w_data, output_data,
141
162
bias ? bias->data <T>() : NULL );
142
163
143
164
// TODO(TJ): fuse act
0 commit comments