@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/operators/math/detail/hl_activation_functions.h"
16
+ #include " paddle/platform/hostdevice.h"
16
17
17
- #ifdef __CUDA_ARCH__
18
- #define INLINE __device__ inline
19
- #else
20
- #define INLINE inline
21
- #endif
18
+ #include < type_traits>
22
19
23
20
namespace paddle {
24
21
namespace operators {
@@ -30,12 +27,12 @@ namespace forward {
30
27
template <class T >
31
28
class lstm {
32
29
public:
33
- INLINE void operator ()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
34
- T &prevState, T &state, T &stateAtv, T &output,
35
- T &checkI, T &checkF, T &checkO,
36
- typename hppl::ForwardActType<T>::type actInput,
37
- typename hppl::ForwardActType<T>::type actGate,
38
- typename hppl::ForwardActType<T>::type actState) {
30
+ HOSTDEVICE void operator ()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
31
+ T &prevState, T &state, T &stateAtv, T &output,
32
+ T &checkI, T &checkF, T &checkO,
33
+ typename hppl::ForwardActType<T>::type actInput,
34
+ typename hppl::ForwardActType<T>::type actGate,
35
+ typename hppl::ForwardActType<T>::type actState) {
39
36
valueIn = actInput (valueIn);
40
37
valueIg = actGate (valueIg + prevState * checkI);
41
38
valueFg = actGate (valueFg + prevState * checkF);
@@ -45,17 +42,19 @@ class lstm {
45
42
output = valueOg * stateAtv;
46
43
}
47
44
#ifndef __NVCC__
48
- #ifndef __AVX__
45
+ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
49
46
static const bool avx = false ;
50
47
#else
51
- static const bool avx = true ;
52
- INLINE void operator ()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
53
- __m256 &valueOg, __m256 &prevState, __m256 &state,
54
- __m256 &stateAtv, __m256 &output, __m256 &checkI,
55
- __m256 &checkF, __m256 &checkO,
56
- hppl::Active<__m256>::forward actInput,
57
- hppl::Active<__m256>::forward actGate,
58
- hppl::Active<__m256>::forward actState) {
48
+ // Only float supports AVX optimization
49
+ static const bool avx = std::is_same<T, float >::value;
50
+
51
+ HOSTDEVICE void operator ()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
52
+ __m256 &valueOg, __m256 &prevState, __m256 &state,
53
+ __m256 &stateAtv, __m256 &output, __m256 &checkI,
54
+ __m256 &checkF, __m256 &checkO,
55
+ hppl::Active<__m256>::forward actInput,
56
+ hppl::Active<__m256>::forward actGate,
57
+ hppl::Active<__m256>::forward actState) {
59
58
valueIn = actInput (valueIn);
60
59
valueIg = actGate (_mm256_add_ps (valueIg, _mm256_mul_ps (prevState, checkI)));
61
60
valueFg = actGate (_mm256_add_ps (valueFg, _mm256_mul_ps (prevState, checkF)));
@@ -76,14 +75,15 @@ namespace backward {
76
75
template <class T >
77
76
class lstm {
78
77
public:
79
- INLINE void operator ()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
80
- T &gradIn, T &gradIg, T &gradFg, T &gradOg,
81
- T &prevState, T &prevStateGrad, T &state, T &stateGrad,
82
- T &stateAtv, T &outputGrad, T &checkI, T &checkF,
83
- T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
84
- typename hppl::BackwardActType<T>::type actInput,
85
- typename hppl::BackwardActType<T>::type actGate,
86
- typename hppl::BackwardActType<T>::type actState) {
78
+ HOSTDEVICE void operator ()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
79
+ T &gradIn, T &gradIg, T &gradFg, T &gradOg,
80
+ T &prevState, T &prevStateGrad, T &state,
81
+ T &stateGrad, T &stateAtv, T &outputGrad,
82
+ T &checkI, T &checkF, T &checkO, T &checkIGrad,
83
+ T &checkFGrad, T &checkOGrad,
84
+ typename hppl::BackwardActType<T>::type actInput,
85
+ typename hppl::BackwardActType<T>::type actGate,
86
+ typename hppl::BackwardActType<T>::type actState) {
87
87
gradOg = actGate (outputGrad * stateAtv, valueOg);
88
88
stateGrad += actState (outputGrad * valueOg, stateAtv) + gradOg * checkO;
89
89
gradIn = actInput (stateGrad * valueIg, valueIn);
@@ -95,21 +95,22 @@ class lstm {
95
95
checkOGrad = gradOg * state;
96
96
}
97
97
#ifndef __NVCC__
98
- #ifndef __AVX__
98
+ #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
99
99
static const bool avx = false ;
100
100
#else
101
- static const bool avx = true ;
102
- INLINE void operator ()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
103
- __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
104
- __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
105
- __m256 &prevStateGrad, __m256 &state,
106
- __m256 &stateGrad, __m256 &stateAtv,
107
- __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
108
- __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
109
- __m256 &checkOGrad,
110
- hppl::Active<__m256>::backward actInput,
111
- hppl::Active<__m256>::backward actGate,
112
- hppl::Active<__m256>::backward actState) {
101
+ // Only float supports AVX optimization
102
+ static const bool avx = std::is_same<T, float >::value;
103
+ HOSTDEVICE void operator ()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
104
+ __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
105
+ __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
106
+ __m256 &prevStateGrad, __m256 &state,
107
+ __m256 &stateGrad, __m256 &stateAtv,
108
+ __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
109
+ __m256 &checkO, __m256 &checkIGrad,
110
+ __m256 &checkFGrad, __m256 &checkOGrad,
111
+ hppl::Active<__m256>::backward actInput,
112
+ hppl::Active<__m256>::backward actGate,
113
+ hppl::Active<__m256>::backward actState) {
113
114
gradOg = actGate (_mm256_mul_ps (outputGrad, stateAtv), valueOg);
114
115
stateGrad = _mm256_add_ps (
115
116
actState (_mm256_mul_ps (outputGrad, valueOg), stateAtv), stateGrad);
0 commit comments