@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
+ #pragma once
16
+ #include < type_traits>
15
17
#include " paddle/fluid/operators/math/detail/activation_functions.h"
16
18
#include " paddle/fluid/platform/hostdevice.h"
17
19
18
- #include < type_traits>
19
-
20
20
namespace paddle {
21
21
namespace operators {
22
22
namespace math {
@@ -27,19 +27,19 @@ namespace forward {
27
27
template <class T >
28
28
class lstm {
29
29
public:
30
- HOSTDEVICE void operator ()(T & value_in, T & value_ig, T & value_fg, T & value_og,
31
- T & prev_state, T & state, T & state_atv, T & output,
32
- T & checkI, T & checkF, T & checkO,
30
+ HOSTDEVICE void operator ()(T * value_in, T * value_ig, T * value_fg, T * value_og,
31
+ T * prev_state, T * state, T * state_atv, T * output,
32
+ T * checkI, T * checkF, T * checkO,
33
33
ActivationType active_node,
34
34
ActivationType active_gate,
35
35
ActivationType active_state) {
36
- value_in = activation (value_in, active_node);
37
- value_ig = activation (value_ig + prev_state * checkI, active_gate);
38
- value_fg = activation (value_fg + prev_state * checkF, active_gate);
39
- state = value_in * value_ig + prev_state * value_fg;
40
- value_og = activation (value_og + state * checkO, active_gate);
41
- state_atv = activation (state, active_state);
42
- output = value_og * state_atv;
36
+ * value_in = activation (* value_in, active_node);
37
+ * value_ig = activation (* value_ig + (* prev_state) * (* checkI) , active_gate);
38
+ * value_fg = activation (* value_fg + (* prev_state) * (* checkF) , active_gate);
39
+ * state = (* value_in) * (* value_ig) + (* prev_state) * (* value_fg) ;
40
+ * value_og = activation (* value_og + (* state) * (* checkO) , active_gate);
41
+ * state_atv = activation (* state, active_state);
42
+ * output = (* value_og) * (* state_atv) ;
43
43
}
44
44
#ifndef __NVCC__
45
45
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -48,27 +48,27 @@ class lstm {
48
48
// Only float support AVX optimization
49
49
static const bool avx = std::is_same<T, float >::value;
50
50
51
- HOSTDEVICE void operator ()(__m256 & value_in, __m256 & value_ig,
52
- __m256 & value_fg, __m256 & value_og,
53
- __m256 & prev_state, __m256 & state,
54
- __m256 & state_atv, __m256 & output, __m256 & checkI,
55
- __m256 & checkF, __m256 & checkO,
51
+ HOSTDEVICE void operator ()(__m256 * value_in, __m256 * value_ig,
52
+ __m256 * value_fg, __m256 * value_og,
53
+ __m256 * prev_state, __m256 * state,
54
+ __m256 * state_atv, __m256 * output, __m256 * checkI,
55
+ __m256 * checkF, __m256 * checkO,
56
56
ActivationType active_node,
57
57
ActivationType active_gate,
58
58
ActivationType active_state) {
59
- value_in = activation (value_in, active_node);
60
- value_ig =
61
- activation ( _mm256_add_ps (value_ig, _mm256_mul_ps (prev_state, checkI)),
62
- active_gate);
63
- value_fg =
64
- activation ( _mm256_add_ps (value_fg, _mm256_mul_ps (prev_state, checkF)),
65
- active_gate);
66
- state = _mm256_add_ps (_mm256_mul_ps (value_in, value_ig),
67
- _mm256_mul_ps (prev_state, value_fg));
68
- value_og = activation (_mm256_add_ps (value_og, _mm256_mul_ps (state, checkO)),
69
- active_gate);
70
- state_atv = activation (state, active_state);
71
- output = _mm256_mul_ps (value_og, state_atv);
59
+ * value_in = activation (* value_in, active_node);
60
+ * value_ig = activation (
61
+ _mm256_add_ps (* value_ig, _mm256_mul_ps (* prev_state, * checkI)),
62
+ active_gate);
63
+ * value_fg = activation (
64
+ _mm256_add_ps (* value_fg, _mm256_mul_ps (* prev_state, * checkF)),
65
+ active_gate);
66
+ * state = _mm256_add_ps (_mm256_mul_ps (* value_in, * value_ig),
67
+ _mm256_mul_ps (* prev_state, * value_fg));
68
+ * value_og = activation (
69
+ _mm256_add_ps (*value_og, _mm256_mul_ps (*state, *checkO)), active_gate);
70
+ * state_atv = activation (* state, active_state);
71
+ * output = _mm256_mul_ps (* value_og, * state_atv);
72
72
}
73
73
#endif
74
74
#endif
@@ -81,26 +81,29 @@ namespace backward {
81
81
template <class T >
82
82
class lstm {
83
83
public:
84
- HOSTDEVICE void operator ()(T & value_in, T & value_ig, T & value_fg, T & value_og,
85
- T & grad_in, T & grad_ig, T & grad_fg, T & grad_og,
86
- T & prev_state, T & prev_state_grad, T & state,
87
- T & state_grad, T & state_atv, T & output_grad,
88
- T & checkI, T & checkF, T & checkO, T & checkIGrad,
89
- T & checkFGrad, T & checkOGrad,
84
+ HOSTDEVICE void operator ()(T * value_in, T * value_ig, T * value_fg, T * value_og,
85
+ T * grad_in, T * grad_ig, T * grad_fg, T * grad_og,
86
+ T * prev_state, T * prev_state_grad, T * state,
87
+ T * state_grad, T * state_atv, T * output_grad,
88
+ T * checkI, T * checkF, T * checkO, T * checkIGrad,
89
+ T * checkFGrad, T * checkOGrad,
90
90
ActivationType active_node,
91
91
ActivationType active_gate,
92
92
ActivationType active_state) {
93
- grad_og = activation (output_grad * state_atv, value_og, active_gate);
94
- state_grad += activation (output_grad * value_og, state_atv, active_state) +
95
- grad_og * checkO;
96
- grad_in = activation (state_grad * value_ig, value_in, active_node);
97
- grad_ig = activation (state_grad * value_in, value_ig, active_gate);
98
- grad_fg = activation (state_grad * prev_state, value_fg, active_gate);
99
- prev_state_grad =
100
- grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
101
- checkIGrad = grad_ig * prev_state;
102
- checkFGrad = grad_fg * prev_state;
103
- checkOGrad = grad_og * state;
93
+ *grad_og =
94
+ activation ((*output_grad) * (*state_atv), *value_og, active_gate);
95
+ *state_grad +=
96
+ activation ((*output_grad) * (*value_og), *state_atv, active_state) +
97
+ (*grad_og) * (*checkO);
98
+ *grad_in = activation ((*state_grad) * (*value_ig), *value_in, active_node);
99
+ *grad_ig = activation ((*state_grad) * (*value_in), *value_ig, active_gate);
100
+ *grad_fg =
101
+ activation ((*state_grad) * (*prev_state), *value_fg, active_gate);
102
+ *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
103
+ (*state_grad) * (*value_fg);
104
+ *checkIGrad = (*grad_ig) * (*prev_state);
105
+ *checkFGrad = (*grad_fg) * (*prev_state);
106
+ *checkOGrad = (*grad_og) * (*state);
104
107
}
105
108
#ifndef __NVCC__
106
109
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -109,32 +112,33 @@ class lstm {
109
112
// Only float support AVX optimization
110
113
static const bool avx = std::is_same<T, float >::value;
111
114
HOSTDEVICE void operator ()(
112
- __m256 & value_in, __m256 & value_ig, __m256 & value_fg, __m256 & value_og,
113
- __m256 & grad_in, __m256 & grad_ig, __m256 & grad_fg, __m256 & grad_og,
114
- __m256 & prev_state, __m256 & prev_state_grad, __m256 & state,
115
- __m256 & state_grad, __m256 & state_atv, __m256 & output_grad,
116
- __m256 & checkI, __m256 & checkF, __m256 & checkO, __m256 & checkIGrad,
117
- __m256 & checkFGrad, __m256 & checkOGrad, ActivationType active_node,
115
+ __m256 * value_in, __m256 * value_ig, __m256 * value_fg, __m256 * value_og,
116
+ __m256 * grad_in, __m256 * grad_ig, __m256 * grad_fg, __m256 * grad_og,
117
+ __m256 * prev_state, __m256 * prev_state_grad, __m256 * state,
118
+ __m256 * state_grad, __m256 * state_atv, __m256 * output_grad,
119
+ __m256 * checkI, __m256 * checkF, __m256 * checkO, __m256 * checkIGrad,
120
+ __m256 * checkFGrad, __m256 * checkOGrad, ActivationType active_node,
118
121
ActivationType active_gate, ActivationType active_state) {
119
- grad_og = activation (_mm256_mul_ps (output_grad, state_atv), value_og,
120
- active_gate);
121
- state_grad = _mm256_add_ps (activation (_mm256_mul_ps (output_grad, value_og),
122
- state_atv, active_state),
123
- state_grad);
124
- state_grad = _mm256_add_ps (_mm256_mul_ps (grad_og, checkO), state_grad);
125
- grad_in =
126
- activation (_mm256_mul_ps (state_grad, value_ig), value_in, active_node);
127
- grad_ig =
128
- activation (_mm256_mul_ps (state_grad, value_in), value_ig, active_gate);
129
- grad_fg = activation (_mm256_mul_ps (state_grad, prev_state), value_fg,
130
- active_gate);
131
- prev_state_grad = _mm256_add_ps (_mm256_mul_ps (grad_ig, checkI),
132
- _mm256_mul_ps (grad_fg, checkF));
133
- prev_state_grad =
134
- _mm256_add_ps (_mm256_mul_ps (state_grad, value_fg), prev_state_grad);
135
- checkIGrad = _mm256_mul_ps (grad_ig, prev_state);
136
- checkFGrad = _mm256_mul_ps (grad_fg, prev_state);
137
- checkOGrad = _mm256_mul_ps (grad_og, state);
122
+ *grad_og = activation (_mm256_mul_ps (*output_grad, *state_atv), *value_og,
123
+ active_gate);
124
+ *state_grad =
125
+ _mm256_add_ps (activation (_mm256_mul_ps (*output_grad, *value_og),
126
+ *state_atv, active_state),
127
+ *state_grad);
128
+ *state_grad = _mm256_add_ps (_mm256_mul_ps (*grad_og, *checkO), *state_grad);
129
+ *grad_in = activation (_mm256_mul_ps (*state_grad, *value_ig), *value_in,
130
+ active_node);
131
+ *grad_ig = activation (_mm256_mul_ps (*state_grad, *value_in), *value_ig,
132
+ active_gate);
133
+ *grad_fg = activation (_mm256_mul_ps (*state_grad, *prev_state), *value_fg,
134
+ active_gate);
135
+ *prev_state_grad = _mm256_add_ps (_mm256_mul_ps (*grad_ig, *checkI),
136
+ _mm256_mul_ps (*grad_fg, *checkF));
137
+ *prev_state_grad =
138
+ _mm256_add_ps (_mm256_mul_ps (*state_grad, *value_fg), *prev_state_grad);
139
+ *checkIGrad = _mm256_mul_ps (*grad_ig, *prev_state);
140
+ *checkFGrad = _mm256_mul_ps (*grad_fg, *prev_state);
141
+ *checkOGrad = _mm256_mul_ps (*grad_og, *state);
138
142
}
139
143
#endif
140
144
#endif
0 commit comments