Skip to content

Commit c744922

Browse files
authored
Merge pull request #15563 from tensor-tang/jit/softmax
refine softmax kernel
2 parents 245b1f0 + d59f733 commit c744922

File tree

22 files changed

+637
-148
lines changed

22 files changed

+637
-148
lines changed

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
158158

159159
using Tensor = paddle::framework::Tensor;
160160

161-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
161+
template <jit::KernelType KT, typename T, typename PlaceType>
162162
void BenchXYZNKernel() {
163163
for (int d : TestSizes()) {
164164
Tensor x, y, z;
@@ -175,7 +175,7 @@ void BenchXYZNKernel() {
175175
}
176176
}
177177

178-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
178+
template <jit::KernelType KT, typename T, typename PlaceType>
179179
void BenchAXYNKernel() {
180180
for (int d : TestSizes()) {
181181
const T a = static_cast<T>(3);
@@ -187,10 +187,23 @@ void BenchAXYNKernel() {
187187
RandomVec<T>(d, x_data);
188188
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
189189
d);
190+
// test inplace
191+
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
192+
d);
193+
}
194+
}
195+
196+
template <jit::KernelType KT, typename T, typename PlaceType>
197+
void BenchXRNKernel() {
198+
for (int d : TestSizes()) {
199+
Tensor x;
200+
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
201+
T res;
202+
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
190203
}
191204
}
192205

193-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
206+
template <jit::KernelType KT, typename T, typename PlaceType>
194207
void BenchXYNKernel() {
195208
for (int d : TestSizes()) {
196209
Tensor x, y;
@@ -203,7 +216,7 @@ void BenchXYNKernel() {
203216
}
204217
}
205218

206-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
219+
template <jit::KernelType KT, typename T, typename PlaceType>
207220
void BenchLSTMKernel() {
208221
for (bool use_peephole : {true, false}) {
209222
for (int d : TestSizes()) {
@@ -240,7 +253,7 @@ void BenchLSTMKernel() {
240253
}
241254
}
242255

243-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
256+
template <jit::KernelType KT, typename T, typename PlaceType>
244257
void BenchGRUKernel() {
245258
for (int d : TestSizes()) {
246259
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
@@ -262,7 +275,7 @@ void BenchGRUKernel() {
262275
}
263276
}
264277

265-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
278+
template <jit::KernelType KT, typename T, typename PlaceType>
266279
void BenchSeqPoolKernel() {
267280
std::vector<jit::SeqPoolType> pool_types = {
268281
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
@@ -284,7 +297,7 @@ void BenchSeqPoolKernel() {
284297
}
285298
}
286299

287-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
300+
template <jit::KernelType KT, typename T, typename PlaceType>
288301
void BenchMatMulKernel() {
289302
for (int m : {1, 2, 3, 4}) {
290303
for (int n : TestSizes()) {
@@ -305,57 +318,64 @@ void BenchMatMulKernel() {
305318
}
306319
}
307320

321+
template <jit::KernelType KT, typename T, typename PlaceType>
322+
void BenchSoftmaxKernel() {
323+
for (int bs : {1, 2, 10}) {
324+
for (int n : TestSizes()) {
325+
Tensor x, y;
326+
x.Resize({bs, n});
327+
y.Resize({bs, n});
328+
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
329+
const T* x_data = x.data<T>();
330+
T* y_data = y.mutable_data<T>(PlaceType());
331+
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
332+
bs);
333+
}
334+
}
335+
}
336+
308337
using T = float;
309-
using PlaceType = paddle::platform::CPUPlace;
338+
using CPUPlace = paddle::platform::CPUPlace;
310339

311340
// xyzn
312-
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, PlaceType>(); }
313-
314-
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); }
315-
316-
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); }
317-
318-
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, PlaceType>(); }
341+
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
342+
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
343+
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
344+
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
319345

320346
// axyn
321-
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, PlaceType>(); }
347+
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
348+
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
322349

323-
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); }
350+
// xrn
351+
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
352+
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
324353

325354
// xyn
326-
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, PlaceType>(); }
327-
328-
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); }
329-
330-
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, PlaceType>(); }
331-
332-
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, PlaceType>(); }
333-
334-
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); }
335-
336-
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, PlaceType>(); }
355+
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
356+
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
357+
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
358+
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
359+
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
360+
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
337361

338362
// lstm and peephole
339-
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); }
340-
341-
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); }
363+
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
364+
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
342365

343366
// gru functions
344-
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); }
345-
346-
BENCH_FP32_CPU(kGRUHtPart1) {
347-
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
348-
}
349-
350-
BENCH_FP32_CPU(kGRUHtPart2) {
351-
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
352-
}
367+
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
368+
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
369+
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
353370

354371
// seq pool function
355-
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); }
372+
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
356373

357374
// matmul
358-
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, PlaceType>(); }
375+
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
376+
377+
// softmax
378+
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
359379

360380
// Benchmark all jit kernels including jitcode, mkl and refer.
361381
// To use this tool, run command: ./benchmark [options...]

paddle/fluid/operators/jit/gen/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,5 @@ USE_JITKERNEL_GEN(kGRUHtPart1)
2828
USE_JITKERNEL_GEN(kGRUHtPart2)
2929
USE_JITKERNEL_GEN(kNCHW16CMulNC)
3030
USE_JITKERNEL_GEN(kSeqPool)
31+
USE_JITKERNEL_GEN(kHMax)
32+
USE_JITKERNEL_GEN(kHSum)

paddle/fluid/operators/jit/gen/act.cc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ void VActJitCode::genCode() {
8181
#define DECLARE_ACT_CREATOR(name) \
8282
class name##Creator : public JitCodeCreator<int> { \
8383
public: \
84-
bool UseMe(const int& attr) const override { \
85-
return platform::MayIUse(platform::avx); \
86-
} \
84+
bool UseMe(const int& attr) const override; \
8785
size_t CodeSize(const int& d) const override; \
8886
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
8987
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
@@ -98,6 +96,30 @@ DECLARE_ACT_CREATOR(VSigmoid);
9896
DECLARE_ACT_CREATOR(VTanh);
9997

10098
// TODO(TJ): tuning use me
// Availability predicates for the activation JIT-code creators: each
// generated kernel is used only when the host CPU supports AVX. They are
// defined out-of-line (rather than in DECLARE_ACT_CREATOR) so individual
// kernels can add size-based tuning, as VExp does below.
bool VReluCreator::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx);
}

bool VSquareCreator::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx);
}

bool VIdentityCreator::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx);
}

// VExp only opts in for d < 32; presumably the generated code loses to the
// MKL/refer implementations at larger sizes -- confirm against benchmarks.
bool VExpCreator::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx) && d < 32;
}

bool VSigmoidCreator::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx);
}

bool VTanhCreator::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx);
}
}
122+
101123
size_t VReluCreator::CodeSize(const int& d) const {
102124
return 96 /* init size */ +
103125
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#include "paddle/fluid/operators/jit/gen/hopv.h"
16+
#include "paddle/fluid/operators/jit/registry.h"
17+
#include "paddle/fluid/platform/cpu_info.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
namespace jit {
22+
namespace gen {
23+
24+
// Emits AVX code that horizontally reduces num_ floats at param_src into a
// single scalar at param_dst. The reduction op (max or add) is applied by
// process(), selected by type_.
void HOPVJitCode::genCode() {
  const int num_blocks = num_ / YMM_FLOAT_BLOCK;
  int offset = 0;

  if (num_blocks > 0) {
    // load one firstly
    vmovups(ymm_tmp, ptr[param_src]);
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
    // Fold each remaining full 8-float block into the YMM accumulator.
    for (int i = 1; i < num_blocks; ++i) {
      vmovups(ymm_src, ptr[param_src + offset]);
      process(ymm_tmp, ymm_src, ymm_tmp);
      offset += sizeof(float) * YMM_FLOAT_BLOCK;
    }
    // Combine high and low 128-bit lanes of the accumulator into xmm_dst.
    // NOTE(review): relies on xmm_tmp aliasing the low lane of ymm_tmp (and
    // xmm_dst the low lane of ymm_dst) -- Xbyak register-index aliasing.
    vextractf128(xmm_dst, ymm_tmp, 1);
    process(xmm_dst, xmm_dst, xmm_tmp);
  } else {
    // No full block: seed the accumulator with the reduction identity --
    // the first element for max, 0.0f for add.
    if (type_ == operand_type::MAX) {
      vbroadcastss(ymm_dst, ptr[param_src]);
    } else if (type_ == operand_type::ADD) {
      vxorps(ymm_dst, ymm_dst, ymm_dst);
    }
  }

  int rest = num_ % YMM_FLOAT_BLOCK;
  // Fold the tail in descending power-of-two chunks (4, 2, 1 floats),
  // interleaved with in-register shuffles that reduce xmm_dst's own lanes.
  if (rest >= 4) {
    vmovups(xmm_src, ptr[param_src + offset]);
    offset += sizeof(float) * 4;
    rest -= 4;
    process(xmm_dst, xmm_dst, xmm_src);
  }

  // 16+8+3 == 0b00011011: reverses the four floats so that pairwise
  // process() folds elements {3,2} into {0,1}.
  vpermilps(xmm_tmp, xmm_dst, 16 + 8 + 3);
  process(xmm_dst, xmm_dst, xmm_tmp);

  if (rest >= 2) {
    vmovq(xmm_src, ptr[param_src + offset]);
    offset += sizeof(float) * 2;
    rest -= 2;
    process(xmm_dst, xmm_dst, xmm_src);
  }

  // Shuffle element 1 down to slot 0 and fold; slot 0 now holds the result.
  vpermilps(xmm_tmp, xmm_dst, 1);
  process(xmm_dst, xmm_dst, xmm_tmp);

  if (rest >= 1) {
    vmovss(xmm_src, ptr[param_src + offset]);
    process(xmm_dst, xmm_dst, xmm_src);
  }
  // Store the scalar result and return from the generated function.
  vmovss(ptr[param_dst], xmm_dst);
  ret();
}
75+
76+
/* Declares a JitCodeCreator for one horizontal-reduction kernel (HMax/HSum):
 * - UseMe: the generated code requires AVX;
 * - CodeSize: 96 bytes of prologue/epilogue plus ~4 instructions of ~8 bytes
 *   per full YMM block (rough upper bound for the code buffer);
 * - CreateJitCode: instantiates the matching <name>JitCode.
 * Keep the macro body free of line comments -- a "//" would swallow the
 * continuation backslashes. */
#define DECLARE_HOP_CREATOR(name)                                           \
  class name##Creator : public JitCodeCreator<int> {                        \
   public:                                                                  \
    bool UseMe(const int& attr) const override {                            \
      return platform::MayIUse(platform::avx);                              \
    }                                                                       \
    size_t CodeSize(const int& d) const override {                          \
      return 96 + d / YMM_FLOAT_BLOCK * 4 * 8;                              \
    }                                                                       \
    std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
      return make_unique<name##JitCode>(attr, CodeSize(attr));              \
    }                                                                       \
  }

DECLARE_HOP_CREATOR(HMax);
DECLARE_HOP_CREATOR(HSum);

#undef DECLARE_HOP_CREATOR
94+
95+
} // namespace gen
96+
} // namespace jit
97+
} // namespace operators
98+
} // namespace paddle
99+
100+
namespace gen = paddle::operators::jit::gen;

// Register the horizontal-max and horizontal-sum JIT-code creators with the
// jit kernel registry so the dispatcher can select them at runtime.
REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator);
REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator);

0 commit comments

Comments
 (0)