@@ -82,7 +82,7 @@ void ROperator_GRU<T>::Initialize(RModel &model) {
8282 for (size_t batch = 0; batch < batch_size; batch++) {
8383 size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
8484 size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
85- i * batch_size * seq_length * fAttrHiddenSize +
85+ i * batch_size * seq_length * fAttrHiddenSize +
8686 + seq *batch_size *fAttrHiddenSize + batch *fAttrHiddenSize;
8787 std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
8888 new_bias + offset);
@@ -146,7 +146,9 @@ void ROperator_GRU<T>::Initialize(RModel &model) {
146146 activation + " not implemented");
147147 }
148148 }
149+ if (fAttrDirection == "reverse") fAttrDirection = "backward";
149150 if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
151+ fAttrDirection != "reverse" &&
150152 fAttrDirection != "bidirectional") {
151153 throw std::runtime_error(
152154 "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
@@ -206,7 +208,7 @@ std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
206208 out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
207209 out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
208210 out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
209-
211+
210212 // feedback
211213 out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
212214 << batch_size * fAttrHiddenSize << ");\n";
@@ -754,7 +756,7 @@ auto ROperator_GRU<T>::Generate(std::string OpName)
754756 << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
755757 << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
756758 << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
757- // endif on seq 0 or not
759+ // endif on seq 0 or not
758760 out << SP << SP << "}\n";
759761 // Add the bias of the recurrence to feedback
760762 if (!fNB.empty()) {
0 commit comments