Skip to content

Commit 4b5986b

Browse files
committed
enable fc op in normal case
1 parent e133df6 commit 4b5986b

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -295,12 +295,6 @@ op_library(channel_recv_op DEPS concurrency)
295295

296296
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
297297

298-
# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF
299-
# Because the fully connected layer has only one MKLDNN's operator
300-
if(NOT WITH_MKLDNN)
301-
list(REMOVE_ITEM GENERAL_OPS fc_op)
302-
endif(NOT WITH_MKLDNN)
303-
304298
foreach(src ${GENERAL_OPS})
305299
op_library(${src})
306300
endforeach()

paddle/fluid/operators/fc_op.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/fc_op.h"
1616
#include <vector>
17+
#include "paddle/fluid/operators/math/blas.h"
1718

1819
DECLARE_int32(paddle_num_threads);
1920

@@ -127,13 +128,13 @@ class FCOpKernel : public framework::OpKernel<T> {
127128
"It must use CPUPlace.");
128129
auto input = ctx.Input<Tensor>("Input");
129130
auto w = ctx.Input<Tensor>("W");
130-
auto b = ctx.Input<Tensor>("Bias");
131+
auto bias = ctx.Input<Tensor>("Bias");
131132
auto output = ctx.Output<Tensor>("Out");
132-
auto in_dims = ctx->GetInputDim("Input");
133-
auto w_dims = ctx->GetInputDim("W");
133+
auto in_dims = input->dims();
134+
auto w_dims = w->dims();
134135

135-
auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
136-
auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
136+
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
137+
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
137138
const T* input_data = input->data<T>();
138139
const T* w_data = w->data<T>();
139140
T* output_data = output->mutable_data<T>(ctx.GetPlace());
@@ -147,7 +148,7 @@ class FCOpKernel : public framework::OpKernel<T> {
147148
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
148149
for (int bs = 0; bs < in_dims[0]; bs++) {
149150
blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
150-
output_data + bs * w_dimws[1]);
151+
output_data + bs * w_dims[1]);
151152
}
152153
}
153154
}

0 commit comments

Comments
 (0)