Skip to content

Commit 9c9dd83

Browse files
committed
Added all 4 Basic Binary Operators:- Add,Sub,Mul and Div.
Added the tests for all 4 Binary Operators
1 parent fdc99d6 commit 9c9dd83

File tree

14 files changed

+264
-24
lines changed

14 files changed

+264
-24
lines changed

tmva/sofie/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
1616
TMVA/OperatorList.hxx
1717
TMVA/RModel.hxx
1818
TMVA/ROperator.hxx
19-
TMVA/ROperator_Add.hxx
19+
TMVA/ROperator_BasicBinary.hxx
2020
TMVA/ROperator_BatchNormalization.hxx
2121
TMVA/ROperator_Conv.hxx
2222
TMVA/ROperator_Gemm.hxx

tmva/sofie/inc/TMVA/OperatorList.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "TMVA/ROperator_LSTM.hxx"
1010
#include "TMVA/ROperator_BatchNormalization.hxx"
1111
#include "TMVA/ROperator_Pool.hxx"
12-
#include "TMVA/ROperator_Add.hxx"
12+
#include "TMVA/ROperator_BasicBinary.hxx"
1313
#include "TMVA/ROperator_Reshape.hxx"
1414
#include "TMVA/ROperator_Slice.hxx"
1515
#include "TMVA/ROperator_GRU.hxx"

tmva/sofie/inc/TMVA/ROperator_Add.hxx renamed to tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef TMVA_SOFIE_ROPERATOR_ADD
2-
#define TMVA_SOFIE_ROPERATOR_ADD
1+
#ifndef TMVA_SOFIE_ROperator_BasicBinary
2+
#define TMVA_SOFIE_ROperator_BasicBinary
33

44
#include "TMVA/SOFIE_common.hxx"
55
#include "TMVA/ROperator.hxx"
@@ -11,23 +11,55 @@ namespace TMVA{
1111
namespace Experimental{
1212
namespace SOFIE{
1313

14+
enum EBasicBinaryOperator { Add, Sub, Mul, Div };
15+
16+
template <typename T, EBasicBinaryOperator Op1>
17+
struct BinaryOperatorTrait {
18+
const char *Name() { return ""; }
19+
const char *Op() { return ""; }
20+
};
21+
template <typename T>
22+
struct BinaryOperatorTrait<T, Add> {
23+
static const char *Name() { return "Add"; }
24+
static const char *Op() { return "+"; }
25+
};
26+
27+
template <typename T>
28+
struct BinaryOperatorTrait<T, Sub> {
29+
static const char *Name() { return "Sub"; }
30+
static const char *Op() { return "-"; }
31+
};
32+
1433
template <typename T>
15-
class ROperator_Add final : public ROperator
16-
{
34+
struct BinaryOperatorTrait<T, Mul> {
35+
static const char *Name() { return "Mul"; }
36+
static const char *Op() { return "*"; }
37+
};
1738

39+
template <typename T>
40+
struct BinaryOperatorTrait<T, Div> {
41+
static const char *Name() { return "Div"; }
42+
static const char *Op() { return "/"; }
43+
};
44+
45+
template<typename T, EBasicBinaryOperator Op>
46+
class ROperator_BasicBinary final : public ROperator{
1847
private:
1948

2049
std::string fNX1;
2150
std::string fNX2;
2251
std::string fNY;
2352
std::vector<size_t> fShape;
2453

54+
// template <typename T, EBasicBinaryOperator Op1>
55+
// BinaryOperatorTrait<T,Op1> *s;
56+
2557
public:
26-
ROperator_Add(){}
27-
ROperator_Add(std::string nameX1, std::string nameX2, std::string nameY):
58+
ROperator_BasicBinary(){}
59+
ROperator_BasicBinary(std::string nameX1, std::string nameX2, std::string nameY):
2860
fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
2961

30-
// type of output given input
62+
// type of output given input
3163
std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
3264
return input;
3365
}
@@ -42,16 +74,16 @@ public:
4274
void Initialize(RModel& model){
4375
// input must be a graph input, or already initialized intermediate tensor
4476
if (model.CheckIfTensorAlreadyExist(fNX1) == false){
45-
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
77+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX1 + "is not found in model");
4678
}
4779
if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
48-
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
80+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX2 + "is not found in model");
4981
}
5082
auto shapeX1 = model.GetTensorShape(fNX1);
5183
auto shapeX2 = model.GetTensorShape(fNX2);
52-
// assume same shape X1 and X2
84+
// assume same shape X1 and X2
5385
if (shapeX1 != shapeX2) {
54-
std::string msg = "TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
86+
std::string msg = "TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
5587
ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
5688
throw std::runtime_error(msg);
5789
}
@@ -62,18 +94,20 @@ public:
6294

6395
std::string Generate(std::string OpName){
6496
OpName = "op_" + OpName;
97+
6598
if (fShape.empty()) {
66-
throw std::runtime_error("TMVA SOFIE Add called to Generate without being initialized first");
99+
throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
67100
}
68101
std::stringstream out;
69102
// int length = 1;
70103
// for(auto& i: fShape){
71104
// length *= i;
72105
// }
73106
size_t length = ConvertShapeToLength(fShape);
74-
out << "\n//------ Add\n";
107+
out << "\n//------ " + std::string(BinaryOperatorTrait<T,Op>::Name())+"\n";
75108
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
76-
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] + tensor_" << fNX2 << "[id];\n";
109+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id]" +
110+
std::string(BinaryOperatorTrait<T,Op>::Op()) + "tensor_" << fNX2 << "[id];\n";
77111
out << SP << "}\n";
78112
return out.str();
79113
}
@@ -85,4 +119,4 @@ public:
85119
}//TMVA
86120

87121

88-
#endif //TMVA_SOFIE_ROPERATOR_Add
122+
#endif //TMVA_SOFIE_ROperator_BasicBinary

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
#include "LinearWithSelu_FromONNX.hxx"
1313
#include "input_models/references/LinearWithSelu.ref.hxx"
1414

15+
#include "Sub_FromONNX.hxx"
16+
#include "input_models/references/Sub.ref.hxx"
17+
18+
#include "Add_FromONNX.hxx"
19+
#include "input_models/references/Add.ref.hxx"
20+
21+
#include "Mul_FromONNX.hxx"
22+
#include "input_models/references/Mul.ref.hxx"
23+
24+
#include "Div_FromONNX.hxx"
25+
#include "input_models/references/Div.ref.hxx"
26+
1527
#include "LinearWithLeakyRelu_FromONNX.hxx"
1628
#include "input_models/references/LinearWithLeakyRelu.ref.hxx"
1729

@@ -134,6 +146,110 @@ TEST(ONNX, Linear32)
134146
}
135147
}
136148

149+
TEST(ONNX, Sub)
150+
{
151+
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
152+
153+
// Preparing the standard input
154+
std::vector<float> input1({
155+
1, 2
156+
});
157+
std::vector<float> input2({
158+
0, 1
159+
});
160+
TMVA_SOFIE_Sub::Session s("Sub_FromONNX.dat");
161+
162+
std::vector<float> output = s.infer(input2.data(),input1.data());
163+
164+
// Checking output size
165+
EXPECT_EQ(output.size(), sizeof(Sub_ExpectedOutput::outputs) / sizeof(float));
166+
167+
float *correct = Sub_ExpectedOutput::outputs;
168+
169+
// Checking every output value, one by one
170+
for (size_t i = 0; i < output.size(); ++i) {
171+
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
172+
}
173+
}
174+
175+
176+
TEST(ONNX, Add)
177+
{
178+
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
179+
180+
// Preparing the standard input
181+
std::vector<float> input1({
182+
1, 2
183+
});
184+
std::vector<float> input2({
185+
0, 1
186+
});
187+
TMVA_SOFIE_Add::Session s("Add_FromONNX.dat");
188+
189+
std::vector<float> output = s.infer(input1.data(),input2.data());
190+
191+
// Checking output size
192+
EXPECT_EQ(output.size(), sizeof(Add_ExpectedOutput::outputs) / sizeof(float));
193+
194+
float *correct = Add_ExpectedOutput::outputs;
195+
196+
// Checking every output value, one by one
197+
for (size_t i = 0; i < output.size(); ++i) {
198+
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
199+
}
200+
}
201+
202+
TEST(ONNX, Mul)
203+
{
204+
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
205+
206+
// Preparing the standard input
207+
std::vector<float> input1({
208+
1, 2
209+
});
210+
std::vector<float> input2({
211+
0, 1
212+
});
213+
TMVA_SOFIE_Mul::Session s("Mul_FromONNX.dat");
214+
215+
std::vector<float> output = s.infer(input1.data(),input2.data());
216+
217+
// Checking output size
218+
EXPECT_EQ(output.size(), sizeof(Mul_ExpectedOutput::outputs) / sizeof(float));
219+
220+
float *correct = Mul_ExpectedOutput::outputs;
221+
222+
// Checking every output value, one by one
223+
for (size_t i = 0; i < output.size(); ++i) {
224+
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
225+
}
226+
}
227+
228+
TEST(ONNX, Div)
229+
{
230+
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
231+
232+
// Preparing the standard input
233+
std::vector<float> input1({
234+
4, 2
235+
});
236+
std::vector<float> input2({
237+
2, 2
238+
});
239+
TMVA_SOFIE_Div::Session s("Div_FromONNX.dat");
240+
241+
std::vector<float> output = s.infer(input2.data(),input1.data());
242+
243+
// Checking output size
244+
EXPECT_EQ(output.size(), sizeof(Div_ExpectedOutput::outputs) / sizeof(float));
245+
246+
float *correct = Div_ExpectedOutput::outputs;
247+
248+
// Checking every output value, one by one
249+
for (size_t i = 0; i < output.size(); ++i) {
250+
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
251+
}
252+
}
137253

138254
TEST(ONNX, Linear64)
139255
{
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
pytorch1.11.0:�
2+
)
3+
onnx::Add_0
4+
onnx::Add_12Add_0"Addtorch-jit-exportZ
5+
onnx::Add_0
6+
7+

8+
Z
9+
onnx::Add_1
10+
11+

12+
b
13+
2
14+
15+

16+
B
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
pytorch1.11.0:�
2+
)
3+
onnx::Div_0
4+
onnx::Div_12Div_0"Divtorch-jit-exportZ
5+
onnx::Div_0
6+
7+

8+
Z
9+
onnx::Div_1
10+
11+

12+
b
13+
2
14+
15+

16+
B
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
pytorch1.11.0:�
2+
)
3+
onnx::Mul_0
4+
onnx::Mul_12Mul_0"Multorch-jit-exportZ
5+
onnx::Mul_0
6+
7+

8+
Z
9+
onnx::Mul_1
10+
11+

12+
b
13+
2
14+
15+

16+
B
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
pytorch1.11.0:�
2+
)
3+
onnx::Sub_0
4+
onnx::Sub_12Sub_0"Subtorch-jit-exportZ
5+
onnx::Sub_0
6+
7+

8+
Z
9+
onnx::Sub_1
10+
11+

12+
b
13+
2
14+
15+

16+
B
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
namespace Add_ExpectedOutput{
2+
float outputs[] = {
3+
1, 3
4+
};
5+
} // namespace Add_ExpectedOutput
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
namespace Div_ExpectedOutput{
2+
float outputs[] = {
3+
2, 1
4+
};
5+
} // namespace Div_ExpectedOutput

0 commit comments

Comments
 (0)