@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/fc_op.h"
16
+ #include < vector>
16
17
17
18
namespace paddle {
18
19
namespace operators {
@@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
29
30
auto w_dims = ctx->GetInputDim (" W" );
30
31
std::vector<int64_t > output_shape ({in_dims[0 ], w_dims[1 ]});
31
32
32
- PADDLE_ENFORCE (in_dims.size () == 4 || in_dims.size () == 2 ,
33
+ PADDLE_ENFORCE (in_dims.size () == 2 || in_dims.size () == 4 ,
33
34
" Fully Connected input should be 2-D or 4-D tensor." );
34
35
35
- PADDLE_ENFORCE (w_dims.size () == 2 ,
36
- " Fully Connected input should be 2-D tensor." );
36
+ PADDLE_ENFORCE (w_dims.size () == 2 || w_dims. size () == 4 ,
37
+ " Fully Connected input should be 2-D or 4-D tensor." );
37
38
38
39
ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape));
39
40
ctx->ShareLoD (" Input" , " Out" );
@@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
73
74
74
75
FCOpMaker::FCOpMaker (OpProto* proto, OpAttrChecker* op_checker)
75
76
: OpProtoAndCheckerMaker(proto, op_checker) {
76
- AddInput (
77
- " Input" ,
78
- " (Tensor) The input tensor of fully connected operator. "
79
- " The format of input tensor is NCHW, where N is batch size, C is the "
80
- " number of channels, H is the height of the feature, "
81
- " and W is the width of the feature." );
77
+ AddInput (" Input" , " (Tensor) The input tensor of fully connected operator. " );
82
78
AddInput (" W" , " (Tensor), The second input tensor of fc op." );
83
- AddOutput (" Out" ,
84
- " (Tensor) The output tensor of fully connected operator. "
85
- " The format of output tensor is also NCHW, "
86
- " where N is batch size, C is the number of channels, "
87
- " H is the height of the feature, "
88
- " and W is the width of the feature." );
79
+ AddOutput (" Out" , " (Tensor) The output tensor of fully connected operator. " );
89
80
AddAttr<bool >(" use_mkldnn" ,
90
81
" (bool, default false) Only used in mkldnn kernel" )
91
82
.SetDefault (false );
0 commit comments