@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmMetaValue will be updated later.

-      lstm_value.checkIg = bias_data + 4 * frame_size;
-      lstm_value.checkFg = lstm_value.checkIg + frame_size;
-      lstm_value.checkOg = lstm_value.checkFg + frame_size;
+      lstm_value.check_ig = bias_data + 4 * frame_size;
+      lstm_value.check_fg = lstm_value.check_ig + frame_size;
+      lstm_value.check_og = lstm_value.check_fg + frame_size;
     } else {
-      lstm_value.checkIg = nullptr;
-      lstm_value.checkFg = nullptr;
-      lstm_value.checkOg = nullptr;
+      lstm_value.check_ig = nullptr;
+      lstm_value.check_fg = nullptr;
+      lstm_value.check_og = nullptr;
     }
-    lstm_value.prevStateValue = nullptr;
+    lstm_value.prev_state_value = nullptr;
     Tensor ordered_c0;
     const size_t* order = batch_gate->lod()[2].data();
     if (cell_t0) {
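When use_peepholes is on, the kernel carves the three peephole weight vectors out of the tail of the bias tensor: the first 4 * frame_size entries hold the gate biases, and the next three frame_size-wide blocks are the input-, forget-, and output-gate peephole weights. A minimal sketch of that layout, using a hypothetical LstmMetaValueSketch stand-in rather than PaddlePaddle's real LstmMetaValue<T>:

// Minimal sketch, assuming a hypothetical stand-in struct; it shows only how
// the peephole pointers above index into a bias buffer laid out as
// [four gate biases | W_ci | W_cf | W_co], each block frame_size wide.
#include <vector>

template <typename T>
struct LstmMetaValueSketch {  // hypothetical stand-in for LstmMetaValue<T>
  T* check_ig = nullptr;      // peephole weights feeding the input gate
  T* check_fg = nullptr;      // peephole weights feeding the forget gate
  T* check_og = nullptr;      // peephole weights feeding the output gate
};

int main() {
  const int frame_size = 8;
  std::vector<float> bias(7 * frame_size, 0.f);  // 4 gate biases + 3 peepholes
  float* bias_data = bias.data();

  LstmMetaValueSketch<float> v;
  v.check_ig = bias_data + 4 * frame_size;  // skip the four gate biases
  v.check_fg = v.check_ig + frame_size;
  v.check_og = v.check_fg + frame_size;
  return 0;
}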
@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
       // to reorder.
       ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
                                  true);
-      lstm_value.prevStateValue = ordered_c0.data<T>();
+      lstm_value.prev_state_value = ordered_c0.data<T>();
     }

     // Use the local variable as here.
@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
                                static_cast<T>(1.0));
       }

-      lstm_value.gateValue = gate_t.data<T>();
-      lstm_value.outputValue = out_t.data<T>();
-      lstm_value.stateValue = cell_t.data<T>();
-      lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
+      lstm_value.gate_value = gate_t.data<T>();
+      lstm_value.output_value = out_t.data<T>();
+      lstm_value.state_value = cell_t.data<T>();
+      lstm_value.state_active_value = cell_pre_act_t.data<T>();
       math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
                                                frame_size, cur_batch_size,
                                                gate_act, cell_act, cand_act);
-      lstm_value.prevStateValue = lstm_value.stateValue;
+      lstm_value.prev_state_value = lstm_value.state_value;
     }

     math::Batch2LoDTensorFunctor<Place, T> to_seq;
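Note the last assignment inside the step loop: once the unit functor has written this step's cell state, the same buffer is handed back as prev_state_value for the next iteration, which is how the recurrence is threaded through the batched loop. A minimal sketch of that hand-off, with illustrative names only (the real loop drives math::LstmUnitFunctor over batch slices):

// Minimal sketch: after each step, the freshly written cell state becomes
// the previous state of the next step. Names here are illustrative only.
#include <vector>

struct StepPointers {                       // hypothetical mirror of two fields
  const float* prev_state_value = nullptr;  // c_{t-1}; nullptr on step 0
  float* state_value = nullptr;             // where c_t is written
};

void RunSteps(std::vector<std::vector<float>>& cell_per_step) {
  StepPointers p;  // prev_state_value starts as nullptr (no initial cell)
  for (auto& cell : cell_per_step) {
    p.state_value = cell.data();
    // ... an LSTM unit would compute c_t here from the gates and
    // p.prev_state_value (treated as zeros when nullptr) ...
    p.prev_state_value = p.state_value;  // carry c_t into the next step
  }
}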
@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     math::LstmMetaValue<T> lstm_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
-      lstm_value.checkIg = bias_data + 4 * frame_size;
-      lstm_value.checkFg = lstm_value.checkIg + frame_size;
-      lstm_value.checkOg = lstm_value.checkFg + frame_size;
+      lstm_value.check_ig = bias_data + 4 * frame_size;
+      lstm_value.check_fg = lstm_value.check_ig + frame_size;
+      lstm_value.check_og = lstm_value.check_fg + frame_size;
     } else {
-      lstm_value.checkIg = nullptr;
-      lstm_value.checkFg = nullptr;
-      lstm_value.checkOg = nullptr;
+      lstm_value.check_ig = nullptr;
+      lstm_value.check_fg = nullptr;
+      lstm_value.check_og = nullptr;
     }

     math::LstmMetaGrad<T> lstm_grad;
@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     }
     if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
       T* bias_g_data = bias_g->data<T>();
-      lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
-      lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
-      lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
+      lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
+      lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
+      lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
     } else {
-      lstm_grad.checkIgGrad = nullptr;
-      lstm_grad.checkFgGrad = nullptr;
-      lstm_grad.checkOgGrad = nullptr;
+      lstm_grad.check_ig_grad = nullptr;
+      lstm_grad.check_fg_grad = nullptr;
+      lstm_grad.check_og_grad = nullptr;
     }

     math::LoDTensor2BatchFunctor<Place, T> to_batch;
@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       Tensor gate = batch_gate->Slice(bstart, bend);
       Tensor cell = batch_cell.Slice(bstart, bend);
       Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
-      lstm_value.gateValue = gate.data<T>();
-      lstm_value.stateValue = cell.data<T>();
-      lstm_value.stateActiveValue = cell_pre_act.data<T>();
+      lstm_value.gate_value = gate.data<T>();
+      lstm_value.state_value = cell.data<T>();
+      lstm_value.state_active_value = cell_pre_act.data<T>();

       Tensor out_g = batch_hidden_g.Slice(bstart, bend);
       Tensor gate_g = batch_gate_g.Slice(bstart, bend);
       Tensor cell_g = batch_cell_g.Slice(bstart, bend);
-      lstm_grad.stateGrad = cell_g.data<T>();
-      lstm_grad.gateGrad = gate_g.data<T>();
-      lstm_grad.outputGrad = out_g.data<T>();
+      lstm_grad.state_grad = cell_g.data<T>();
+      lstm_grad.gate_grad = gate_g.data<T>();
+      lstm_grad.output_grad = out_g.data<T>();

       if (n > 0) {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
         Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
-        lstm_value.prevStateValue = cell_pre.data<T>();
-        lstm_grad.prevStateGrad = cell_pre_g.data<T>();
+        lstm_value.prev_state_value = cell_pre.data<T>();
+        lstm_grad.prev_state_grad = cell_pre_g.data<T>();
       } else {
-        lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
-        lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
+        lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
+        lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
       }

       int cur_batch_size = bend - bstart;
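The backward pass selects the previous cell state the same way the forward pass produced it: for n > 0 it slices the previous batch out of batch_cell (and batch_cell_g for the gradient), and on the first step it falls back to the reordered c0, leaving nullptr when no initial cell or gradient was supplied. A minimal sketch of that selection, using a hypothetical helper:

// Minimal sketch with illustrative names: interior steps read the previous
// batch's cell slice (and its gradient); the first step falls back to the
// reordered initial cell c0, or nullptr when it is absent.
struct PrevState {
  const float* value = nullptr;
  float* grad = nullptr;
};

PrevState PickPrevState(int n, const float* cell_pre, float* cell_pre_g,
                        const float* ordered_c0, float* ordered_c0_g) {
  PrevState p;
  if (n > 0) {              // not the first step: use the previous batch
    p.value = cell_pre;
    p.grad = cell_pre_g;
  } else {                  // first step: initial cell state, if provided
    p.value = ordered_c0;   // may be nullptr when c0 was not given
    p.grad = ordered_c0_g;  // may be nullptr when c0's gradient is not needed
  }
  return p;
}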