
Commit f0f0699

Merge pull request #12878 from tensor-tang/feature/op/attention_lstm
Add attention lstm cpu forward
2 parents 5ea7bf8 + 4e538db; commit f0f0699

File tree

12 files changed: +969, -83 lines changed


CMakeLists.txt

Lines changed: 0 additions & 6 deletions
@@ -138,12 +138,6 @@ else()
   set(THIRD_PARTY_BUILD_TYPE Release)
 endif()
 
-if(WITH_MKL)
-  option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
-  if (MKL_SPLIT_GEMM)
-    add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
-  endif()
-endif()
 set(WITH_MKLML ${WITH_MKL})
 if (NOT DEFINED WITH_MKLDNN)
   if (WITH_MKL AND AVX2_FOUND)

paddle/fluid/operators/attention_lstm_op.cc

Lines changed: 422 additions & 0 deletions
Large diffs are not rendered by default.

paddle/fluid/operators/attention_lstm_op.h

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+class AttentionLSTMOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class AttentionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle

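Note: since the 422-line attention_lstm_op.cc diff is not rendered above, the sketch below shows how declarations like these are typically filled in and registered for a CPU-only forward op. It is an illustrative outline only, with assumed input/output names, comments standing in for the real shape checks, and an assumed kernel class name; it is not the code added by this commit.

// Illustrative sketch only -- not the attention_lstm_op.cc added by this
// commit. Names such as "X", "Hidden" and AttentionLSTMKernel are assumptions.
#include "paddle/fluid/operators/attention_lstm_op.h"
#include "paddle/fluid/framework/data_type.h"

namespace paddle {
namespace operators {

void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  // Typical pattern: enforce that inputs/outputs exist, then set output dims.
  // PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
  // ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
}

framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // Choose the kernel from the data type of the input LoDTensor.
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<LoDTensor>("X")->type()),
      ctx.device_context());
}

void AttentionLSTMOpMaker::Make() {
  AddInput("X", "(LoDTensor) input sequences.");
  AddOutput("Hidden", "(LoDTensor) hidden states of the attention LSTM.");
  AddComment("Attention LSTM operator; this PR adds the CPU forward only.");
}

}  // namespace operators
}  // namespace paddle

// Registration would then look roughly like:
// namespace ops = paddle::operators;
// REGISTER_OPERATOR(attention_lstm, ops::AttentionLSTMOp,
//                   ops::AttentionLSTMOpMaker);
// REGISTER_OP_CPU_KERNEL(attention_lstm, ops::AttentionLSTMKernel<float>,
//                        ops::AttentionLSTMKernel<double>);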
paddle/fluid/operators/fusion_lstm_op.h

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-// #include <string>
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {

paddle/fluid/operators/math/blas.h

Lines changed: 33 additions & 0 deletions
@@ -90,6 +90,11 @@ class Blas {
   void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
             int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
 
+  template <typename T>
+  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
+            T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
+            int ldc) const;
+
 #ifdef PADDLE_WITH_MKLML
   template <typename T>
   T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
@@ -109,6 +114,10 @@ class Blas {
   void GEMM_FREE(T* data) const;
 #endif
 
+  template <typename T>
+  void MatMul(const int M, const int N, const int K, const T* A, const T* B,
+              T* C) const;
+
   template <typename T>
   void MatMul(const framework::Tensor& mat_a, bool trans_a,
               const framework::Tensor& mat_b, bool trans_b, T alpha,
@@ -140,10 +149,19 @@ class Blas {
   template <typename T>
   void VCOPY(int n, const T* x, T* y) const;
 
+  template <typename T>
+  void VEXP(int n, const T* x, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
             T* C) const;
 
+  template <typename T>
+  T DOT(int n, const T* x, const T* y) const;
+
+  template <typename T>
+  void SCAL(int n, const T a, T* x) const;
+
   template <typename T>
   void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
                    int K, T alpha, const T* A, const T* B, T beta, T* C,
@@ -215,11 +233,26 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VCOPY<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VEXP(ARGS... args) const {
+    Base()->template VEXP<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
   }
 
+  template <typename... ARGS>
+  T DOT(ARGS... args) const {
+    return Base()->template DOT<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void SCAL(ARGS... args) const {
+    Base()->template SCAL<T>(args...);
+  }
+
   template <typename... ARGS>
   void BatchedGEMM(ARGS... args) const {
     Base()->template BatchedGEMM<T>(args...);

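Taken together, these additions give CPU kernels a flat row-major MatMul plus element-wise exp, a dot product, and in-place scaling. Below is a minimal usage sketch, assuming the existing math::GetBlas helper; the function and buffer names are illustrative and not part of this commit.

// Illustrative only: how a CPU kernel might call the new BlasT<> helpers.
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

template <typename T>
void BlasHelpersSketch(const platform::CPUDeviceContext& dev_ctx, int M, int N,
                       int K, const T* a, const T* b, T* c, T* e) {
  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
  // c (M x N, row-major) = a (M x K) * b (K x N), with alpha = 1, beta = 0.
  blas.MatMul(M, N, K, a, b, c);
  // e[i] = exp(c[i]) for the first N elements of c.
  blas.VEXP(N, c, e);
  // Inner product of the two length-N vectors c and e.
  T d = blas.DOT(N, c, e);
  // Scale e in place by that scalar.
  blas.SCAL(N, d, e);
}

}  // namespace operators
}  // namespace paddle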
paddle/fluid/operators/math/blas_impl.h

Lines changed: 126 additions & 63 deletions
@@ -73,6 +73,16 @@ struct CBlas<float> {
     platform::dynload::cblas_sgemv(args...);
   }
 
+  template <typename... ARGS>
+  static float DOT(ARGS... args) {
+    return platform::dynload::cblas_sdot(args...);
+  }
+
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_sscal(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_sgemm_batch(args...);
@@ -87,6 +97,11 @@ struct CBlas<float> {
   static void VMUL(ARGS... args) {
     platform::dynload::vsMul(args...);
   }
+
+  template <typename... ARGS>
+  static void VEXP(ARGS... args) {
+    platform::dynload::vsExp(args...);
+  }
 };
 
 template <>
@@ -138,6 +153,16 @@ struct CBlas<double> {
     platform::dynload::cblas_dgemv(args...);
   }
 
+  template <typename... ARGS>
+  static double DOT(ARGS... args) {
+    return platform::dynload::cblas_ddot(args...);
+  }
+
+  template <typename... ARGS>
+  static void SCAL(ARGS... args) {
+    platform::dynload::cblas_dscal(args...);
+  }
+
   template <typename... ARGS>
   static void GEMM_BATCH(ARGS... args) {
     platform::dynload::cblas_dgemm_batch(args...);
@@ -152,6 +177,11 @@ struct CBlas<double> {
   static void VMUL(ARGS... args) {
     platform::dynload::vdMul(args...);
   }
+
+  template <typename... ARGS>
+  static void VEXP(ARGS... args) {
+    platform::dynload::vdExp(args...);
+  }
 };
 
 #else
@@ -210,71 +240,16 @@ struct CBlas<platform::float16> {
     PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
+  static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
+  static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
+  static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
     PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
   }
 #endif
 };
 
-template <typename T>
-inline bool UseXSMM(const int &m, const int &n, const int &k, bool transa,
-                    bool transb, const T &alpha, const T &beta) {
-#ifdef PADDLE_WITH_LIBXSMM
-  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
-  // But the threshold is custom
-  constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
-  if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
-      std::abs<T>(alpha - static_cast<T>(1) >
-                  std::numeric_limits<T>::epsilon()) ||
-      std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
-    return false;
-  } else {
-    return true;
-  }
-#endif
-  return false;
-}
-
-template <>
-inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
-                                       bool transa, bool transb,
-                                       const platform::float16 &alpha,
-                                       const platform::float16 &beta) {
-  return false;
-}
-
-template <typename T>
-inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
-                      CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
-                      const T *A, int lda, const T *B, int ldb, T beta, T *C,
-                      int ldc) {
-#ifdef PADDLE_WITH_LIBXSMM
-  if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
-                 beta)) {
-    // Note: SMM use ColMajor
-    const char transa = 'N';
-    const char transb = 'N';
-    CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
-                       &beta, C, &ldc);
-    return;
-  }
-#endif
-
-#ifdef PADDLE_MKL_SPLIT_GEMM
-  constexpr int bs = 2;
-  if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
-    for (int off = 0; off < M; off += bs) {
-      CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
-                     A + off * lda, lda, B, ldb, beta, C + off * ldb, ldc);
-    }
-    return;
-  }
-#endif
-  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
-                 beta, C, ldc);
-}
-
 #ifdef PADDLE_WITH_MKLML
 template <>
 template <typename T>
@@ -319,8 +294,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
   int lda = (transA == CblasNoTrans) ? K : M;
   int ldb = (transB == CblasNoTrans) ? N : K;
   int ldc = N;
-  GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
-               beta, C, ldc);
+  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+                 beta, C, ldc);
 }
 
 template <>
@@ -329,9 +304,20 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             int lda, const T *B, int ldb,
                                             T beta, T *C, int ldc) const {
-  GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
-               transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
-               lda, B, ldb, beta, C, ldc);
+  CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+                 transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+                 lda, B, ldb, beta, C, ldc);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
+                                            CBLAS_TRANSPOSE transB, int M,
+                                            int N, int K, T alpha, const T *A,
+                                            int lda, const T *B, int ldb,
+                                            T beta, T *C, int ldc) const {
+  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+                 beta, C, ldc);
 }
 
 template <typename DeviceContext>
@@ -399,6 +385,47 @@ void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
 #endif
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VEXP(n, x, y);
+#else
+  // try to find if openblas support vexp
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::exp(x[i]);
+  }
+#endif
+}
+
+template <>
+template <typename T>
+T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  return CBlas<T>::DOT(n, x, 1, y, 1);
+#else
+  // try to find if openblas support cblas_dot
+  T sum = 0;
+  for (int i = 0; i < n; ++i) {
+    sum += x[i] * y[i];
+  }
+  return sum;
+#endif
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::SCAL(int n, const T a, T *x) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::SCAL(n, a, x, 1);
+#else
+  // try to find if openblas support cblas_scal
+  for (int i = 0; i < n; ++i) {
+    x[i] = a * x[i];
+  }
+#endif
+}
+
 template <>
 template <typename T>
 void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
@@ -440,6 +467,42 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
 #endif
 }
 
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const int M, const int N, const int K,
+                                 const T *A, const T *B, T *C) const {
+  this->template GEMM<T>(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
+                         static_cast<T>(1), A, K, B, N, static_cast<T>(0), C,
+                         N);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::MatMul(const int M, const int N,
+                                              const int K, const T *A,
+                                              const T *B, T *C) const {
+#ifdef PADDLE_WITH_LIBXSMM
+  // Refer to https://github.com/hfp/libxsmm/blob/master/README.md
+  // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
+
+  // Since the matrix is very small,
+  // so the unit of calculation is already very fast,
+  // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead,
+  // use xsmm directly.
+  // Note: SMM use ColMajor
+  const char transa = 'N';
+  const char transb = 'N';
+  const T alpha = static_cast<T>(1);
+  const T beta = static_cast<T>(0);
+  CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta,
+                     C, &N);
+  return;
+#endif
+
+  CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
+                 static_cast<T>(1), A, K, B, N, static_cast<T>(0), C, N);
+}
+
 template <typename DeviceContext>
 template <typename T>
 void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,

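A note on the new MatMul specialization shown above: libxsmm's SMM_GEMM works in column-major order, so the row-major product C = A * B is obtained by swapping the operands, because a row-major buffer read as column-major is the transpose of the matrix and (A*B)^T = B^T * A^T. The standalone check below illustrates that identity with a naive column-major GEMM standing in for SMM_GEMM; it is illustrative code, not part of the commit.

// Standalone check of the ColMajor trick used by the LIBXSMM branch of MatMul.
// naive_colmajor_gemm is a stand-in for SMM_GEMM here.
#include <cstdio>

// D(m x n) = X(m x k) * Y(k x n); all buffers are column-major.
void naive_colmajor_gemm(int m, int n, int k, const float* X, int ldx,
                         const float* Y, int ldy, float* D, int ldd) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += X[i + p * ldx] * Y[p + j * ldy];
      D[i + j * ldd] = acc;
    }
  }
}

int main() {
  const int M = 2, N = 3, K = 2;
  // Row-major A (M x K) and B (K x N).
  float A[M * K] = {1, 2, 3, 4};
  float B[K * N] = {5, 6, 7, 8, 9, 10};
  float C[M * N];
  // Column-major GEMM with B and A swapped writes row-major A*B into C,
  // mirroring SMM_GEMM(&transa, &transb, &N, &M, &K, ..., B, &N, A, &K, ...).
  naive_colmajor_gemm(N, M, K, B, N, A, K, C, N);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) printf("%g ", C[i * N + j]);
    printf("\n");
  }
  // Expected row-major A*B: 21 24 27 / 47 54 61
  return 0;
}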