Skip to content

Commit d5114c6

Browse files
committed
- Reviewers suggesstions to fused_embedding_fc_lstm_op
1 parent 7ab5626 commit d5114c6

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
16+
#include <algorithm>
1617
#include <string>
1718
#include "paddle/fluid/framework/lod_tensor.h"
1819

@@ -98,17 +99,17 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
9899

99100
// Copy only gate biases values (only actual bias data, not peephole
100101
// weights)
101-
std::vector<float> combined_biases(n, 0.0f);
102-
memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(),
103-
n * sizeof(float));
102+
std::vector<float> combined_biases;
103+
combined_biases.reserve(n);
104+
std::copy_n(lstm_bias_tensor.data<float>(), n,
105+
std::back_inserter(combined_biases));
104106

105107
if (with_fc_bias) {
106108
// Add FC-bias with LSTM-bias (into GEMM result to be)
107109
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
108110
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
109111
for (int i = 0; i < fc_bias_tensor.numel(); i++) {
110-
combined_biases[i] =
111-
lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
112+
combined_biases[i] += fc_bias_tensor.data<float>()[i];
112113
}
113114
}
114115

paddle/fluid/operators/fused_embedding_fc_lstm_op.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,6 @@ void FusedEmbeddingFCLSTMOp::InferShape(
6363
auto embeddings_dims = ctx->GetInputDim("Embeddings");
6464
PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
6565
"The rank of Input(Embeddings) should be 2.");
66-
// PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
67-
// "The first dimension of Input(Embeddings) "
68-
// "should be %d.",
69-
// x_dims[1]);
7066

7167
auto wh_dims = ctx->GetInputDim("WeightH");
7268
int frame_size = wh_dims[1] / 4;

0 commit comments

Comments
 (0)