
Commit a9be969

Update attention_lstm_fuse_pass.cc
1 parent 5d5b70a commit a9be969

1 file changed (+3, -3)

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 3 additions & 3 deletions
The change re-indents three continuation lines: each removed (-) and added (+) pair below is otherwise identical, differing only in leading whitespace.
@@ -217,10 +217,10 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
   float* out_data = out->mutable_data<float>(platform::CPUPlace());
   std::array<const float*, 4> tensors(
       {{W_forget_w0.data<float>(), W_input_w0.data<float>(),
-        W_output_w0.data<float>(), W_cell_w0.data<float>()}});
+        W_output_w0.data<float>(), W_cell_w0.data<float>()}});
   std::array<const float*, 4> tensors1(
       {{W_forget_w1.data<float>(), W_input_w1.data<float>(),
-        W_output_w1.data<float>(), W_cell_w1.data<float>()}});
+        W_output_w1.data<float>(), W_cell_w1.data<float>()}});

   for (int row = 0; row < D; row++) {
     for (int col = 0; col < 4; col++) {
@@ -244,7 +244,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                      LoDTensor* out) {
   std::array<const float*, 4> tensors(
       {{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
-        B_cell.data<float>()}});
+        B_cell.data<float>()}});

   PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
   int D = B_forget.dims()[0];
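
For context on what these hunks touch (the surrounding function bodies are not part of the diff): PrepareLSTMWeight gathers the four per-gate weight blocks (forget, input, output, cell) into the tensors/tensors1 arrays and copies them side by side into one fused weight matrix via the row/col loop, and PrepareLSTMBias does the same for the four bias vectors. Below is a minimal, self-contained sketch of that packing pattern, assuming row-major [rows x D] gate blocks fused into a [rows x 4D] matrix; PackGateWeights and the plain-buffer types are hypothetical stand-ins for Paddle's LoDTensor-based code, not its actual API.

// Sketch only: illustrates the gate-packing pattern, not Paddle code.
// Each of the four per-gate weight blocks is assumed to be a row-major
// [rows x D] buffer; they are copied side by side into a fused
// [rows x 4D] matrix in the gate order forget, input, output, cell.
#include <array>
#include <cstring>
#include <vector>

std::vector<float> PackGateWeights(const std::array<const float*, 4>& gates,
                                   int rows, int D) {
  std::vector<float> fused(static_cast<size_t>(rows) * 4 * D);
  for (int row = 0; row < rows; row++) {
    for (int col = 0; col < 4; col++) {
      // Destination: row `row`, gate block `col` of the fused [rows x 4D] matrix.
      float* dst = fused.data() + (4 * D * row) + (D * col);
      // Source: row `row` of the col-th per-gate [rows x D] block.
      const float* src = gates[col] + D * row;
      std::memcpy(dst, src, D * sizeof(float));
    }
  }
  return fused;
}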
