@@ -14,6 +14,7 @@ limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/fc_op.h"
16
16
#include < vector>
17
+ #include " paddle/fluid/operators/math/blas.h"
17
18
18
19
DECLARE_int32 (paddle_num_threads);
19
20
@@ -127,13 +128,13 @@ class FCOpKernel : public framework::OpKernel<T> {
127
128
" It must use CPUPlace." );
128
129
auto input = ctx.Input <Tensor>(" Input" );
129
130
auto w = ctx.Input <Tensor>(" W" );
130
- auto b = ctx.Input <Tensor>(" Bias" );
131
+ auto bias = ctx.Input <Tensor>(" Bias" );
131
132
auto output = ctx.Output <Tensor>(" Out" );
132
- auto in_dims = ctx-> GetInputDim ( " Input " );
133
- auto w_dims = ctx-> GetInputDim ( " W " );
133
+ auto in_dims = input-> dims ( );
134
+ auto w_dims = w-> dims ( );
134
135
135
- auto & dev_ctx = ctx.template device_context <CPUDeviceContext>();
136
- auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
136
+ auto & dev_ctx = ctx.template device_context <platform:: CPUDeviceContext>();
137
+ auto blas = math::GetBlas<platform:: CPUDeviceContext, T>(dev_ctx);
137
138
const T* input_data = input->data <T>();
138
139
const T* w_data = w->data <T>();
139
140
T* output_data = output->mutable_data <T>(ctx.GetPlace ());
@@ -147,7 +148,7 @@ class FCOpKernel : public framework::OpKernel<T> {
147
148
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
148
149
for (int bs = 0 ; bs < in_dims[0 ]; bs++) {
149
150
blas.AXPY (w_dims[1 ], static_cast <T>(1 ), bias_data,
150
- output_data + bs * w_dimws [1 ]);
151
+ output_data + bs * w_dims [1 ]);
151
152
}
152
153
}
153
154
}
0 commit comments