Skip to content

Commit 6447155

Browse files
authored
Merge pull request #13851 from tensor-tang/fea/jitkernel_peephole
Fea jitkernel lstm peephole
2 parents 7fb5b66 + bcb8ea3 commit 6447155

17 files changed

+2274
-403
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ op_library(flatten_op DEPS reshape_op)
300300
op_library(sequence_pad_op DEPS sequence_padding)
301301
op_library(unstack_op DEPS stack_op)
302302
op_library(fake_quantize_op DEPS memory)
303-
op_library(fusion_lstm_op DEPS cpu_lstm_compute)
303+
op_library(fusion_lstm_op DEPS jit_kernel)
304304
if (WITH_GPU)
305305
op_library(conv_op DEPS vol2col depthwise_conv im2col)
306306
op_library(layer_norm_op DEPS cub)

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 102 additions & 261 deletions
Large diffs are not rendered by default.

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ math_library(im2col)
4545
if (NOT WIN32) # windows do not support avx functions yet.
4646
math_library(gru_compute DEPS activation_functions math_function)
4747
math_library(lstm_compute DEPS activation_functions)
48-
# TODO(TJ): ugly workaround, clean me
49-
cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info)
5048
endif (NOT WIN32)
5149

5250
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
@@ -76,3 +74,7 @@ if(WITH_GPU)
7674
endif()
7775
cc_test(concat_test SRCS concat_test.cc DEPS concat)
7876
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
77+
cc_library(jit_kernel
78+
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
79+
DEPS cpu_info cblas activation_functions)
80+
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)

paddle/fluid/operators/math/cpu_lstm_compute.cc

Lines changed: 0 additions & 43 deletions
This file was deleted.

paddle/fluid/operators/math/cpu_lstm_compute.h

Lines changed: 0 additions & 64 deletions
This file was deleted.

paddle/fluid/operators/math/cpu_vec.h

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,8 @@ inline void vec_scal<float, platform::jit::avx2>(const int n, const float a,
125125
}
126126

127127
template <>
128-
inline void vec_scal<float, platform::jit::avx512_common>(const int n,
129-
const float a,
130-
const float* x,
131-
float* y) {
128+
inline void vec_scal<float, platform::jit::avx512f>(const int n, const float a,
129+
const float* x, float* y) {
132130
// TODO(TJ): enable me
133131
vec_scal<float, platform::jit::avx2>(n, a, x, y);
134132
}
@@ -181,10 +179,10 @@ inline void vec_bias_sub<float, platform::jit::avx2>(const int n, const float a,
181179
}
182180

183181
template <>
184-
inline void vec_bias_sub<float, platform::jit::avx512_common>(const int n,
185-
const float a,
186-
const float* x,
187-
float* y) {
182+
inline void vec_bias_sub<float, platform::jit::avx512f>(const int n,
183+
const float a,
184+
const float* x,
185+
float* y) {
188186
// TODO(TJ): enable me
189187
vec_bias_sub<float, platform::jit::avx2>(n, a, x, y);
190188
}
@@ -242,7 +240,7 @@ inline void vec_cross<float, platform::jit::avx2>(const int n, const float* x,
242240
}
243241

244242
template <>
245-
inline void vec_cross<float, platform::jit::avx512_common>(
243+
inline void vec_cross<float, platform::jit::avx512f>(
246244
const int n, const float* x, const float* y, const float* z, float* out) {
247245
// TODO(TJ): enable me
248246
vec_cross<float, platform::jit::avx>(n, x, y, z, out);
@@ -296,10 +294,10 @@ inline void vec_add_bias<float, platform::jit::avx2>(const int n, const float a,
296294
}
297295

298296
template <>
299-
inline void vec_add_bias<float, platform::jit::avx512_common>(const int n,
300-
const float a,
301-
const float* x,
302-
float* y) {
297+
inline void vec_add_bias<float, platform::jit::avx512f>(const int n,
298+
const float a,
299+
const float* x,
300+
float* y) {
303301
// TODO(TJ): enable me
304302
vec_add_bias<float, platform::jit::avx2>(n, a, x, y);
305303
}
@@ -390,9 +388,9 @@ inline void vec_sigmoid<float, platform::jit::avx2>(const int n, const float* x,
390388
}
391389

392390
template <>
393-
inline void vec_sigmoid<float, platform::jit::avx512_common>(const int n,
394-
const float* x,
395-
float* y) {
391+
inline void vec_sigmoid<float, platform::jit::avx512f>(const int n,
392+
const float* x,
393+
float* y) {
396394
// TODO(TJ): enable me
397395
vec_sigmoid<float, platform::jit::avx2>(n, x, y);
398396
}
@@ -454,9 +452,8 @@ inline void vec_relu<float, platform::jit::avx2>(const int n, const float* x,
454452
}
455453

456454
template <>
457-
inline void vec_relu<float, platform::jit::avx512_common>(const int n,
458-
const float* x,
459-
float* y) {
455+
inline void vec_relu<float, platform::jit::avx512f>(const int n, const float* x,
456+
float* y) {
460457
// TODO(TJ): enable me
461458
vec_relu<float, platform::jit::avx2>(n, x, y);
462459
}

paddle/fluid/operators/math/cpu_vec_test.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) {
110110
TestAndBench<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
111111
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
112112
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
113-
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512_common>,
113+
TestAndBench<float>(sz, vec_sigmoid<float, jit::avx512f>,
114114
ref_sigmoid<float>);
115115
}
116116
TestAndBench<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
@@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) {
123123
TestAndBench<float>(sz, vec_tanh<float>, ref_tanh<float>);
124124
TestAndBench<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
125125
TestAndBench<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
126-
TestAndBench<float>(sz, vec_tanh<float, jit::avx512_common>,
127-
ref_tanh<float>);
126+
TestAndBench<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
128127
}
129128
TestAndBench<double>(30, vec_tanh<double>, ref_tanh<double>);
130129
}
@@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) {
136135
TestAndBench<float>(sz, vec_relu<float>, ref_relu<float>);
137136
TestAndBench<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
138137
TestAndBench<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
139-
TestAndBench<float>(sz, vec_relu<float, jit::avx512_common>,
140-
ref_relu<float>);
138+
TestAndBench<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
141139
}
142140
TestAndBench<double>(30, vec_relu<double>, ref_relu<double>);
143141
}
@@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) {
170168
TestInplace<float>(sz, vec_sigmoid<float>, ref_sigmoid<float>);
171169
TestInplace<float>(sz, vec_sigmoid<float, jit::avx>, ref_sigmoid<float>);
172170
TestInplace<float>(sz, vec_sigmoid<float, jit::avx2>, ref_sigmoid<float>);
173-
TestInplace<float>(sz, vec_sigmoid<float, jit::avx512_common>,
171+
TestInplace<float>(sz, vec_sigmoid<float, jit::avx512f>,
174172
ref_sigmoid<float>);
175173
}
176174
TestInplace<double>(30, vec_sigmoid<double>, ref_sigmoid<double>);
@@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) {
183181
TestInplace<float>(sz, vec_tanh<float>, ref_tanh<float>);
184182
TestInplace<float>(sz, vec_tanh<float, jit::avx>, ref_tanh<float>);
185183
TestInplace<float>(sz, vec_tanh<float, jit::avx2>, ref_tanh<float>);
186-
TestInplace<float>(sz, vec_tanh<float, jit::avx512_common>,
187-
ref_tanh<float>);
184+
TestInplace<float>(sz, vec_tanh<float, jit::avx512f>, ref_tanh<float>);
188185
}
189186
TestInplace<double>(30, vec_tanh<double>, ref_tanh<double>);
190187
}
@@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) {
196193
TestInplace<float>(sz, vec_relu<float>, ref_relu<float>);
197194
TestInplace<float>(sz, vec_relu<float, jit::avx>, ref_relu<float>);
198195
TestInplace<float>(sz, vec_relu<float, jit::avx2>, ref_relu<float>);
199-
TestInplace<float>(sz, vec_relu<float, jit::avx512_common>,
200-
ref_relu<float>);
196+
TestInplace<float>(sz, vec_relu<float, jit::avx512f>, ref_relu<float>);
201197
}
202198
TestInplace<double>(30, vec_relu<double>, ref_relu<double>);
203199
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/math/jit_kernel.h"
16+
#include <iostream>
17+
#include <string>
18+
19+
namespace paddle {
20+
namespace operators {
21+
namespace math {
22+
namespace jitkernel {
23+
24+
namespace jit = platform::jit;
25+
26+
// Returns the calling thread's kernel pool.
//
// The pool is a function-local thread_local object, so each thread gets its
// own independent instance, constructed the first time that thread calls
// Instance(). Block-scope thread_local already implies static storage
// duration, so no extra `static` keyword is needed.
KernelPool& KernelPool::Instance() {
  thread_local KernelPool per_thread_pool;
  return per_thread_pool;
}
30+
31+
std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
32+
if (kers_.find(key) == kers_.end()) {
33+
return nullptr;
34+
}
35+
return kers_.at(key);
36+
}
37+
38+
} // namespace jitkernel
39+
} // namespace math
40+
} // namespace operators
41+
} // namespace paddle

0 commit comments

Comments
 (0)