Skip to content

Commit b652821

Browse files
sidgoyal78 and abhinavarora
authored and committed
Fix cpplint errors in lstm kernel (#10394)
1 parent bd66eed commit b652821

File tree

3 files changed

+100
-96
lines changed

3 files changed

+100
-96
lines changed

paddle/fluid/operators/math/detail/lstm_cpu_kernel.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
5959
r_prev_state = value.prev_state_value[i];
6060
}
6161

62-
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
63-
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
64-
active_gate, active_state);
62+
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
63+
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
64+
active_node, active_gate, active_state);
6565

6666
value_in[i] = r_value_in;
6767
value_ig[i] = r_value_ig;
@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
125125
r_prev_state = value.prev_state_value[i];
126126
}
127127

128-
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
129-
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
130-
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
131-
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
132-
active_state);
128+
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
129+
&r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
130+
&r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
131+
&r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
132+
active_node, active_gate, active_state);
133133

134134
grad_in[i] = r_grad_in;
135135
grad_ig[i] = r_grad_ig;
@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
186186
r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
187187
}
188188

189-
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
190-
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
191-
active_gate, active_state);
189+
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
190+
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
191+
active_node, active_gate, active_state);
192192

193193
value_in[i] = r_value_in;
194194
value_ig[i] = r_value_ig;
@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
258258
r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
259259
}
260260

261-
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
262-
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
263-
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
264-
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
265-
active_state);
261+
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
262+
&r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
263+
&r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
264+
&r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
265+
active_node, active_gate, active_state);
266266

267267
grad_in[i] = r_grad_in;
268268
grad_ig[i] = r_grad_ig;

paddle/fluid/operators/math/detail/lstm_gpu_kernel.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
7070
r_prev_state = value.prev_state_value[frame_idx];
7171
}
7272

73-
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
74-
r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
75-
active_state);
73+
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
74+
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
75+
active_node, active_gate, active_state);
7676

7777
value.gate_value[frame_idx] = r_value_in;
7878
value.gate_value[frame_idx + frame_size] = r_value_ig;
@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
145145
r_prev_state = value.prev_state_value[frame_idx];
146146
}
147147

148-
op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
149-
r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
150-
r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
151-
r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
152-
active_state);
148+
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
149+
&r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
150+
&r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
151+
&r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
152+
active_gate, active_state);
153153

154154
grad.gate_grad[frame_idx] = r_grad_in;
155155
grad.gate_grad[frame_idx + frame_size] = r_grad_ig;

paddle/fluid/operators/math/detail/lstm_kernel.h

Lines changed: 76 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#pragma once
16+
#include <type_traits>
1517
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1618
#include "paddle/fluid/platform/hostdevice.h"
1719

18-
#include <type_traits>
19-
2020
namespace paddle {
2121
namespace operators {
2222
namespace math {
@@ -27,19 +27,19 @@ namespace forward {
2727
template <class T>
2828
class lstm {
2929
public:
30-
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
31-
T &prev_state, T &state, T &state_atv, T &output,
32-
T &checkI, T &checkF, T &checkO,
30+
HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
31+
T *prev_state, T *state, T *state_atv, T *output,
32+
T *checkI, T *checkF, T *checkO,
3333
ActivationType active_node,
3434
ActivationType active_gate,
3535
ActivationType active_state) {
36-
value_in = activation(value_in, active_node);
37-
value_ig = activation(value_ig + prev_state * checkI, active_gate);
38-
value_fg = activation(value_fg + prev_state * checkF, active_gate);
39-
state = value_in * value_ig + prev_state * value_fg;
40-
value_og = activation(value_og + state * checkO, active_gate);
41-
state_atv = activation(state, active_state);
42-
output = value_og * state_atv;
36+
*value_in = activation(*value_in, active_node);
37+
*value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
38+
*value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
39+
*state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
40+
*value_og = activation(*value_og + (*state) * (*checkO), active_gate);
41+
*state_atv = activation(*state, active_state);
42+
*output = (*value_og) * (*state_atv);
4343
}
4444
#ifndef __NVCC__
4545
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -48,27 +48,27 @@ class lstm {
4848
// Only float support AVX optimization
4949
static const bool avx = std::is_same<T, float>::value;
5050

51-
HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
52-
__m256 &value_fg, __m256 &value_og,
53-
__m256 &prev_state, __m256 &state,
54-
__m256 &state_atv, __m256 &output, __m256 &checkI,
55-
__m256 &checkF, __m256 &checkO,
51+
HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig,
52+
__m256 *value_fg, __m256 *value_og,
53+
__m256 *prev_state, __m256 *state,
54+
__m256 *state_atv, __m256 *output, __m256 *checkI,
55+
__m256 *checkF, __m256 *checkO,
5656
ActivationType active_node,
5757
ActivationType active_gate,
5858
ActivationType active_state) {
59-
value_in = activation(value_in, active_node);
60-
value_ig =
61-
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
62-
active_gate);
63-
value_fg =
64-
activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
65-
active_gate);
66-
state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
67-
_mm256_mul_ps(prev_state, value_fg));
68-
value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
69-
active_gate);
70-
state_atv = activation(state, active_state);
71-
output = _mm256_mul_ps(value_og, state_atv);
59+
*value_in = activation(*value_in, active_node);
60+
*value_ig = activation(
61+
_mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)),
62+
active_gate);
63+
*value_fg = activation(
64+
_mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)),
65+
active_gate);
66+
*state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
67+
_mm256_mul_ps(*prev_state, *value_fg));
68+
*value_og = activation(
69+
_mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
70+
*state_atv = activation(*state, active_state);
71+
*output = _mm256_mul_ps(*value_og, *state_atv);
7272
}
7373
#endif
7474
#endif
@@ -81,26 +81,29 @@ namespace backward {
8181
template <class T>
8282
class lstm {
8383
public:
84-
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
85-
T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
86-
T &prev_state, T &prev_state_grad, T &state,
87-
T &state_grad, T &state_atv, T &output_grad,
88-
T &checkI, T &checkF, T &checkO, T &checkIGrad,
89-
T &checkFGrad, T &checkOGrad,
84+
HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
85+
T *grad_in, T *grad_ig, T *grad_fg, T *grad_og,
86+
T *prev_state, T *prev_state_grad, T *state,
87+
T *state_grad, T *state_atv, T *output_grad,
88+
T *checkI, T *checkF, T *checkO, T *checkIGrad,
89+
T *checkFGrad, T *checkOGrad,
9090
ActivationType active_node,
9191
ActivationType active_gate,
9292
ActivationType active_state) {
93-
grad_og = activation(output_grad * state_atv, value_og, active_gate);
94-
state_grad += activation(output_grad * value_og, state_atv, active_state) +
95-
grad_og * checkO;
96-
grad_in = activation(state_grad * value_ig, value_in, active_node);
97-
grad_ig = activation(state_grad * value_in, value_ig, active_gate);
98-
grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
99-
prev_state_grad =
100-
grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
101-
checkIGrad = grad_ig * prev_state;
102-
checkFGrad = grad_fg * prev_state;
103-
checkOGrad = grad_og * state;
93+
*grad_og =
94+
activation((*output_grad) * (*state_atv), *value_og, active_gate);
95+
*state_grad +=
96+
activation((*output_grad) * (*value_og), *state_atv, active_state) +
97+
(*grad_og) * (*checkO);
98+
*grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
99+
*grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
100+
*grad_fg =
101+
activation((*state_grad) * (*prev_state), *value_fg, active_gate);
102+
*prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
103+
(*state_grad) * (*value_fg);
104+
*checkIGrad = (*grad_ig) * (*prev_state);
105+
*checkFGrad = (*grad_fg) * (*prev_state);
106+
*checkOGrad = (*grad_og) * (*state);
104107
}
105108
#ifndef __NVCC__
106109
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -109,32 +112,33 @@ class lstm {
109112
// Only float support AVX optimization
110113
static const bool avx = std::is_same<T, float>::value;
111114
HOSTDEVICE void operator()(
112-
__m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
113-
__m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
114-
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
115-
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
116-
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
117-
__m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
115+
__m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og,
116+
__m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og,
117+
__m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
118+
__m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
119+
__m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
120+
__m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
118121
ActivationType active_gate, ActivationType active_state) {
119-
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
120-
active_gate);
121-
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
122-
state_atv, active_state),
123-
state_grad);
124-
state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
125-
grad_in =
126-
activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
127-
grad_ig =
128-
activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
129-
grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
130-
active_gate);
131-
prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
132-
_mm256_mul_ps(grad_fg, checkF));
133-
prev_state_grad =
134-
_mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
135-
checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
136-
checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
137-
checkOGrad = _mm256_mul_ps(grad_og, state);
122+
*grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
123+
active_gate);
124+
*state_grad =
125+
_mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
126+
*state_atv, active_state),
127+
*state_grad);
128+
*state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
129+
*grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
130+
active_node);
131+
*grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
132+
active_gate);
133+
*grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg,
134+
active_gate);
135+
*prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
136+
_mm256_mul_ps(*grad_fg, *checkF));
137+
*prev_state_grad =
138+
_mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad);
139+
*checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state);
140+
*checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state);
141+
*checkOGrad = _mm256_mul_ps(*grad_og, *state);
138142
}
139143
#endif
140144
#endif

0 commit comments

Comments (0)