Skip to content

Commit 8117725

Browse files
committed
add jit kernel hsum, hmax and softmax refer code
test=develop
1 parent 67e4450 commit 8117725

File tree

8 files changed

+269
-121
lines changed

8 files changed

+269
-121
lines changed

paddle/fluid/operators/jit/benchmark.cc

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
158158

159159
using Tensor = paddle::framework::Tensor;
160160

161-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
161+
template <jit::KernelType KT, typename T, typename PlaceType>
162162
void BenchXYZNKernel() {
163163
for (int d : TestSizes()) {
164164
Tensor x, y, z;
@@ -175,7 +175,7 @@ void BenchXYZNKernel() {
175175
}
176176
}
177177

178-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
178+
template <jit::KernelType KT, typename T, typename PlaceType>
179179
void BenchAXYNKernel() {
180180
for (int d : TestSizes()) {
181181
const T a = static_cast<T>(3);
@@ -190,7 +190,17 @@ void BenchAXYNKernel() {
190190
}
191191
}
192192

193-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
193+
template <jit::KernelType KT, typename T, typename PlaceType>
194+
void BenchXRNKernel() {
195+
for (int d : TestSizes()) {
196+
Tensor x;
197+
RandomVec<T>(d, x.mutable_data<T>({d}, PlaceType()));
198+
T res;
199+
BenchAllImpls<KT, jit::XRNTuples<T>, PlaceType>(d, x.data<T>(), &res, d);
200+
}
201+
}
202+
203+
template <jit::KernelType KT, typename T, typename PlaceType>
194204
void BenchXYNKernel() {
195205
for (int d : TestSizes()) {
196206
Tensor x, y;
@@ -203,7 +213,7 @@ void BenchXYNKernel() {
203213
}
204214
}
205215

206-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
216+
template <jit::KernelType KT, typename T, typename PlaceType>
207217
void BenchLSTMKernel() {
208218
for (bool use_peephole : {true, false}) {
209219
for (int d : TestSizes()) {
@@ -240,7 +250,7 @@ void BenchLSTMKernel() {
240250
}
241251
}
242252

243-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
253+
template <jit::KernelType KT, typename T, typename PlaceType>
244254
void BenchGRUKernel() {
245255
for (int d : TestSizes()) {
246256
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
@@ -262,7 +272,7 @@ void BenchGRUKernel() {
262272
}
263273
}
264274

265-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
275+
template <jit::KernelType KT, typename T, typename PlaceType>
266276
void BenchSeqPoolKernel() {
267277
std::vector<jit::SeqPoolType> pool_types = {
268278
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
@@ -284,7 +294,7 @@ void BenchSeqPoolKernel() {
284294
}
285295
}
286296

287-
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
297+
template <jit::KernelType KT, typename T, typename PlaceType>
288298
void BenchMatMulKernel() {
289299
for (int m : {1, 2, 3, 4}) {
290300
for (int n : TestSizes()) {
@@ -305,57 +315,64 @@ void BenchMatMulKernel() {
305315
}
306316
}
307317

318+
template <jit::KernelType KT, typename T, typename PlaceType>
319+
void BenchSoftmaxKernel() {
320+
for (int bs : {1, 2, 10}) {
321+
for (int n : TestSizes()) {
322+
Tensor x, y;
323+
x.Resize({bs, n});
324+
y.Resize({bs, n});
325+
RandomVec<T>(bs * n, x.mutable_data<T>(PlaceType()), -2.f, 2.f);
326+
const T* x_data = x.data<T>();
327+
T* y_data = y.mutable_data<T>(PlaceType());
328+
BenchAllImpls<KT, jit::SoftmaxTuples<T>, PlaceType>(n, x_data, y_data, n,
329+
bs);
330+
}
331+
}
332+
}
333+
308334
using T = float;
309-
using PlaceType = paddle::platform::CPUPlace;
335+
using CPUPlace = paddle::platform::CPUPlace;
310336

311337
// xyzn
312-
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, PlaceType>(); }
313-
314-
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, PlaceType>(); }
315-
316-
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>(); }
317-
318-
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, PlaceType>(); }
338+
BENCH_FP32_CPU(kVMul) { BenchXYZNKernel<jit::kVMul, T, CPUPlace>(); }
339+
BENCH_FP32_CPU(kVAdd) { BenchXYZNKernel<jit::kVAdd, T, CPUPlace>(); }
340+
BENCH_FP32_CPU(kVAddRelu) { BenchXYZNKernel<jit::kVAddRelu, T, CPUPlace>(); }
341+
BENCH_FP32_CPU(kVSub) { BenchXYZNKernel<jit::kVSub, T, CPUPlace>(); }
319342

320343
// axyn
321-
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, PlaceType>(); }
344+
BENCH_FP32_CPU(kVScal) { BenchAXYNKernel<jit::kVScal, T, CPUPlace>(); }
345+
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, CPUPlace>(); }
322346

323-
BENCH_FP32_CPU(kVAddBias) { BenchAXYNKernel<jit::kVAddBias, T, PlaceType>(); }
347+
// xrn
348+
BENCH_FP32_CPU(kHSum) { BenchXRNKernel<jit::kHSum, T, CPUPlace>(); }
349+
BENCH_FP32_CPU(kHMax) { BenchXRNKernel<jit::kHMax, T, CPUPlace>(); }
324350

325351
// xyn
326-
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, PlaceType>(); }
327-
328-
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, PlaceType>(); }
329-
330-
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, PlaceType>(); }
331-
332-
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, PlaceType>(); }
333-
334-
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, PlaceType>(); }
335-
336-
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, PlaceType>(); }
352+
BENCH_FP32_CPU(kVRelu) { BenchXYNKernel<jit::kVRelu, T, CPUPlace>(); }
353+
BENCH_FP32_CPU(kVIdentity) { BenchXYNKernel<jit::kVIdentity, T, CPUPlace>(); }
354+
BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
355+
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
356+
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
357+
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
337358

338359
// lstm and peephole
339-
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>(); }
340-
341-
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>(); }
360+
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
361+
BENCH_FP32_CPU(kLSTMC1H1) { BenchLSTMKernel<jit::kLSTMC1H1, T, CPUPlace>(); }
342362

343363
// gru functions
344-
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); }
345-
346-
BENCH_FP32_CPU(kGRUHtPart1) {
347-
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
348-
}
349-
350-
BENCH_FP32_CPU(kGRUHtPart2) {
351-
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
352-
}
364+
BENCH_FP32_CPU(kGRUH1) { BenchGRUKernel<jit::kGRUH1, T, CPUPlace>(); }
365+
BENCH_FP32_CPU(kGRUHtPart1) { BenchGRUKernel<jit::kGRUHtPart1, T, CPUPlace>(); }
366+
BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
353367

354368
// seq pool function
355-
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>(); }
369+
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
356370

357371
// matmul
358-
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, PlaceType>(); }
372+
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
373+
374+
// softmax
375+
BENCH_FP32_CPU(kSoftmax) { BenchSoftmaxKernel<jit::kSoftmax, T, CPUPlace>(); }
359376

360377
// Benchmark all jit kernels including jitcode, mkl and refer.
361378
// To use this tool, run command: ./benchmark [options...]

paddle/fluid/operators/jit/helper.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ const char* to_string(KernelType kt) {
4949
ONE_CASE(kNCHW16CMulNC);
5050
ONE_CASE(kSeqPool);
5151
ONE_CASE(kMatMul);
52+
ONE_CASE(kHMax);
53+
ONE_CASE(kHSum);
54+
ONE_CASE(kSoftmax);
5255
default:
5356
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
5457
return "NOT JITKernel";

paddle/fluid/operators/jit/kernel_base.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace paddle {
2020
namespace operators {
2121
namespace jit {
2222

23+
// TODO(TJ): reorder by alphabet
2324
typedef enum {
2425
kNone = 0,
2526
kVMul = 1,
@@ -44,6 +45,9 @@ typedef enum {
4445
kNCHW16CMulNC,
4546
kSeqPool,
4647
kMatMul,
48+
kHSum,  // horizontal sum
49+
kHMax,  // horizontal max
50+
kSoftmax,
4751
} KernelType;
4852

4953
typedef enum {
@@ -70,6 +74,10 @@ struct XYNTuples {
7074
typedef void (*func_type)(const T*, T*, int);
7175
};
7276

77+
// x (input array), res (single output value), and n (length): horizontal reduction
78+
template <typename T>
79+
struct XRNTuples : public XYNTuples<T> {};
80+
7381
typedef struct {
7482
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
7583
const void* ct_1;
@@ -159,6 +167,13 @@ struct LayerNormTuples {
159167
const float, int);
160168
};
161169

170+
template <typename T>
171+
struct SoftmaxTuples {
172+
typedef T data_type;
173+
typedef int attr_type;
174+
typedef void (*func_type)(const T*, T*, int, int);
175+
};
176+
162177
// nChw16c = nChw16c .* NC
163178
template <typename T>
164179
struct NCHW16CMulNCTuples {

paddle/fluid/operators/jit/refer/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,6 @@ USE_JITKERNEL_REFER(kNCHW16CMulNC)
2929
USE_JITKERNEL_REFER(kSeqPool)
3030
USE_JITKERNEL_REFER(kMatMul)
3131
USE_JITKERNEL_REFER(kVSquare)
32+
USE_JITKERNEL_REFER(kHSum)
33+
USE_JITKERNEL_REFER(kHMax)
34+
USE_JITKERNEL_REFER(kSoftmax)

paddle/fluid/operators/jit/refer/refer.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,9 @@ REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
5252

5353
REGISTER_REFER_KERNEL(kMatMul, MatMul);
5454

55+
REGISTER_REFER_KERNEL(kHMax, HMax);
56+
REGISTER_REFER_KERNEL(kHSum, HSum);
57+
58+
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
59+
5560
#undef REGISTER_REFER_KERNEL

paddle/fluid/operators/jit/refer/refer.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,40 @@ void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
378378
}
379379
}
380380

381+
template <typename T>
382+
void HMax(const T* x, T* res, int n) {
383+
res[0] = x[0];
384+
for (int i = 1; i < n; ++i) {
385+
res[0] = res[0] < x[i] ? x[i] : res[0];
386+
}
387+
}
388+
389+
template <typename T>
390+
void HSum(const T* x, T* res, int n) {
391+
res[0] = x[0];
392+
for (int i = 1; i < n; ++i) {
393+
res[0] += x[i];
394+
}
395+
}
396+
397+
// y = e^(x - max(x))
398+
// y = y / sum(y)
399+
template <typename T>
400+
void Softmax(const T* x, T* y, int n, int bs = 1) {
401+
for (int i = 0; i < bs; ++i) {
402+
T scalar;
403+
HMax(x, &scalar, n);
404+
scalar = static_cast<T>(0) - scalar;
405+
VAddBias(&scalar, x, y, n); // x - max
406+
VExp(y, y, n);
407+
HSum(y, &scalar, n);
408+
scalar = static_cast<T>(1) / scalar;
409+
VScal(&scalar, y, y, n);
410+
x += n;
411+
y += n;
412+
}
413+
}
414+
381415
#define DECLARE_REFER_KERNEL(name, tuples) \
382416
template <typename T> \
383417
class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -421,6 +455,11 @@ DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
421455

422456
DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
423457

458+
DECLARE_REFER_KERNEL(HMax, XRNTuples);
459+
DECLARE_REFER_KERNEL(HSum, XRNTuples);
460+
461+
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
462+
424463
#undef DECLARE_REFER_KERNEL
425464

426465
} // namespace refer

0 commit comments

Comments
 (0)