Skip to content

Commit 46e14bb

Browse files
committed
Enforce: 2 and 4 dims, remove information about out in format
1 parent 32f8ac7 commit 46e14bb

File tree

2 files changed

+9
-18
lines changed

2 files changed

+9
-18
lines changed

paddle/fluid/operators/fc_mkldnn_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
125125
auto input = ctx.Input<Tensor>("Input");
126126
auto w = ctx.Input<Tensor>("W");
127127

128-
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 2,
128+
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
129129
"Input must be with 2 or 4 dimensions, i.e. NCHW");
130-
PADDLE_ENFORCE(w->dims().size() == 2,
131-
"Weights must be with 2 dimensions, i.e. NC");
130+
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
131+
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
132132

133133
bool with_bias = ctx.Attr<bool>("bias_attr");
134134
MKLDNNMD<Tensor> md(input, w, with_bias);

paddle/fluid/operators/fc_op.cc

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/fc_op.h"
16+
#include <vector>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
2930
auto w_dims = ctx->GetInputDim("W");
3031
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
3132

32-
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 2,
33+
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
3334
"Fully Connected input should be 2-D or 4-D tensor.");
3435

35-
PADDLE_ENFORCE(w_dims.size() == 2,
36-
"Fully Connected input should be 2-D tensor.");
36+
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
37+
"Fully Connected input should be 2-D or 4-D tensor.");
3738

3839
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
3940
ctx->ShareLoD("Input", "Out");
@@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
7374

7475
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
7576
: OpProtoAndCheckerMaker(proto, op_checker) {
76-
AddInput(
77-
"Input",
78-
"(Tensor) The input tensor of fully connected operator. "
79-
"The format of input tensor is NCHW, where N is batch size, C is the "
80-
"number of channels, H is the height of the feature, "
81-
"and W is the width of the feature.");
77+
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
8278
AddInput("W", "(Tensor), The second input tensor of fc op.");
83-
AddOutput("Out",
84-
"(Tensor) The output tensor of fully connected operator. "
85-
"The format of output tensor is also NCHW, "
86-
"where N is batch size, C is the number of channels, "
87-
"H is the height of the feature, "
88-
"and W is the width of the feature.");
79+
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
8980
AddAttr<bool>("use_mkldnn",
9081
"(bool, default false) Only used in mkldnn kernel")
9182
.SetDefault(false);

0 commit comments

Comments
 (0)