Skip to content

Commit 3e7caa2

Browse files
committed
[tmva][sofie] Fix in GRU for reverse. It is equivalent to backward
1 parent 45f3e7b commit 3e7caa2

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tmva/sofie/inc/TMVA/ROperator_GRU.icc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void ROperator_GRU<T>::Initialize(RModel &model) {
8282
for (size_t batch = 0; batch < batch_size; batch++) {
8383
size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
8484
size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
85-
i * batch_size * seq_length * fAttrHiddenSize +
85+
i * batch_size * seq_length * fAttrHiddenSize +
8686
+ seq *batch_size *fAttrHiddenSize + batch *fAttrHiddenSize;
8787
std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
8888
new_bias + offset);
@@ -146,7 +146,9 @@ void ROperator_GRU<T>::Initialize(RModel &model) {
146146
activation + " not implemented");
147147
}
148148
}
149+
if (fAttrDirection == "reverse") fAttrDirection = "backward";
149150
if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
151+
fAttrDirection != "reverse" &&
150152
fAttrDirection != "bidirectional") {
151153
throw std::runtime_error(
152154
"TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
@@ -206,7 +208,7 @@ std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
206208
out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
207209
out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
208210
out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
209-
211+
210212
// feedback
211213
out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
212214
<< batch_size * fAttrHiddenSize << ");\n";
@@ -754,7 +756,7 @@ auto ROperator_GRU<T>::Generate(std::string OpName)
754756
<< OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
755757
<< " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
756758
<< OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
757-
// endif on seq 0 or not
759+
// endif on seq 0 or not
758760
out << SP << SP << "}\n";
759761
// Add the bias of the recurrence to feedback
760762
if (!fNB.empty()) {

0 commit comments

Comments
 (0)