
Commit f799ed3

All four binary operators (Add, Sub, Mul, Div) are added with corresponding unit tests, and multi-directional broadcasting functionality is added for SOFIE.
1 parent 4e0ef33 commit f799ed3

19 files changed: +460 -107 lines changed

tmva/sofie/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
 TMVA/OperatorList.hxx
 TMVA/RModel.hxx
 TMVA/ROperator.hxx
-TMVA/ROperator_Add.hxx
+TMVA/ROperator_BasicBinary.hxx
 TMVA/ROperator_BatchNormalization.hxx
 TMVA/ROperator_Conv.hxx
 TMVA/ROperator_Gemm.hxx

tmva/sofie/inc/TMVA/OperatorList.hxx

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 #include "TMVA/ROperator_LSTM.hxx"
 #include "TMVA/ROperator_BatchNormalization.hxx"
 #include "TMVA/ROperator_Pool.hxx"
-#include "TMVA/ROperator_Add.hxx"
+#include "TMVA/ROperator_BasicBinary.hxx"
 #include "TMVA/ROperator_Reshape.hxx"
 #include "TMVA/ROperator_Slice.hxx"
 #include "TMVA/ROperator_GRU.hxx"

tmva/sofie/inc/TMVA/ROperator_Add.hxx

Lines changed: 0 additions & 88 deletions
This file was deleted.

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+#ifndef TMVA_SOFIE_ROperator_BasicBinary
+#define TMVA_SOFIE_ROperator_BasicBinary
+
+#include "TMVA/SOFIE_common.hxx"
+#include "TMVA/ROperator.hxx"
+#include "TMVA/RModel.hxx"
+
+#include <sstream>
+
+namespace TMVA{
+namespace Experimental{
+namespace SOFIE{
+
+enum EBasicBinaryOperator { Add, Sub, Mul, Div };
+
+template <typename T, EBasicBinaryOperator Op1>
+struct BinaryOperatorTrait {
+   const char *Name() { return ""; }
+   const char *Op() { return ""; }
+};
+template <typename T>
+struct BinaryOperatorTrait<T, Add> {
+   static const char *Name() { return "Add"; }
+   static const char *Op() { return "+"; }
+};
+
+template <typename T>
+struct BinaryOperatorTrait<T, Sub> {
+   static const char *Name() { return "Sub"; }
+   static const char *Op() { return "-"; }
+};
+
+template <typename T>
+struct BinaryOperatorTrait<T, Mul> {
+   static const char *Name() { return "Mul"; }
+   static const char *Op() { return "*"; }
+};
+
+template <typename T>
+struct BinaryOperatorTrait<T, Div> {
+   static const char *Name() { return "Div"; }
+   static const char *Op() { return "/"; }
+};
+
+template<typename T, EBasicBinaryOperator Op>
+class ROperator_BasicBinary final : public ROperator{
+private:
+
+   std::string fNX1;
+   std::string fNX2;
+   std::string fNY;
+   std::vector<size_t> fShape;
+
+   // template <typename T, EBasicBinaryOperator Op1>
+   // BinaryOperatorTrait<T,Op1> *s;
+
+public:
+   ROperator_BasicBinary(){}
+   ROperator_BasicBinary(std::string nameX1, std::string nameX2, std::string nameY):
+      fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
+
+   // type of output given input
+   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
+      return input;
+   }
+
+   // shape of output tensors given input tensors
+   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
+      // assume now inputs have same shape (no broadcasting)
+      auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
+      return ret;
+   }
+
+   void Initialize(RModel& model){
+      // input must be a graph input, or already initialized intermediate tensor
+      if (model.CheckIfTensorAlreadyExist(fNX1) == false){
+         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX1 + "is not found in model");
+      }
+      if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
+         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX2 + "is not found in model");
+      }
+      auto shapeX1 = model.GetTensorShape(fNX1);
+      auto shapeX2 = model.GetTensorShape(fNX2);
+      // assume same shape X1 and X2
+      if (shapeX1 != shapeX2) {
+         fShape = UTILITY::Multidirectional_broadcast(shapeX1,shapeX2);
+         size_t length1 = ConvertShapeToLength(shapeX1);
+         size_t length2 = ConvertShapeToLength(shapeX2);
+         size_t output_length = ConvertShapeToLength(fShape);
+         if(length1 != length2 || length1 != output_length){
+            throw std::runtime_error(std::string("TMVA SOFIE Binary Op does not support input tensors with different lengths. The output tensor should also have the same length as the input tensors."));
+         }
+      }
+      else if(shapeX1 == shapeX2){
+         fShape = shapeX1;
+      }
+      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX1), fShape);
+   }
+
+
+   std::string Generate(std::string OpName){
+      OpName = "op_" + OpName;
+
+      if (fShape.empty()) {
+         throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
+      }
+      std::stringstream out;
+      // int length = 1;
+      // for(auto& i: fShape){
+      //    length *= i;
+      // }
+      size_t length = ConvertShapeToLength(fShape);
+      out << "\n//------ " + std::string(BinaryOperatorTrait<T,Op>::Name())+"\n";
+      out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
+      out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id]" +
+         std::string(BinaryOperatorTrait<T,Op>::Op()) + "tensor_" << fNX2 << "[id];\n";
+      out << SP << "}\n";
+      return out.str();
+   }
+
+};
+
+}//SOFIE
+}//Experimental
+}//TMVA
+
+
+#endif //TMVA_SOFIE_ROperator_BasicBinary
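
For illustration only (not part of the commit), a minimal sketch of how this operator class is driven, assuming an existing SOFIE RModel named model that already holds two float input tensors "A" and "B" of the same shape; the variable model and the tensor names "A", "B", "C" are placeholder assumptions:

   // Hypothetical usage sketch; `model`, "A", "B" and "C" are assumed names.
   using namespace TMVA::Experimental::SOFIE;

   ROperator_BasicBinary<float, Add> addOp("A", "B", "C");
   addOp.Initialize(model);                 // registers the intermediate output tensor "C"
   std::string code = addOp.Generate("0");  // emits the element-wise loop as a string

   // `code` then contains, roughly:
   //   //------ Add
   //      for (size_t id = 0; id < <length> ; id++){
   //         tensor_C[id] = tensor_A[id]+tensor_B[id];
   //      }

The same pattern applies to Sub, Mul, and Div through the corresponding BinaryOperatorTrait specializations.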

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ ETensorType GetTemplatedType(T /*obj*/ ){
 namespace UTILITY{
 template<typename T>
 T* Unidirectional_broadcast(const T* original_data, const std::vector<size_t> original_shape, const std::vector<size_t> target_shape);
+std::vector<size_t> Multidirectional_broadcast(const std::vector<size_t> input1_shape, const std::vector<size_t> input2_shape);
 std::string Clean_name(std::string input_tensor_name);
 
tmva/sofie/src/SOFIE_common.cxx

Lines changed: 43 additions & 0 deletions
@@ -135,6 +135,49 @@ T* UTILITY::Unidirectional_broadcast(const T* original_data, const std::vector<s
    return new_datavector;
 }
 
+
+
+std::vector<size_t> UTILITY::Multidirectional_broadcast(std::vector<size_t> input1_shape, std::vector<size_t> input2_shape)
+{
+   std::vector<size_t> input_shape = (input1_shape.size() > input2_shape.size())?input1_shape:input2_shape;
+   std::vector<size_t> output_shape(input_shape);
+
+   if(input1_shape.size() < input2_shape.size()){
+      // Check if input1_shape.size() < input2_shape.size() we insert in the shape vector values of 1 at the beginning of the tensor until input1_shape.size() == input2_shape.size()
+      auto it = input1_shape.begin();
+      while (input1_shape.size() < input2_shape.size()) {
+         it = input1_shape.insert(it, 1);
+      }
+   }
+   else if(input2_shape.size() < input1_shape.size()){
+      // Check if input2_shape.size() < input1_shape.size() we insert in the shape vector values of 1 at the beginning of the tensor until input1_shape.size() == input2_shape.size()
+      auto it = input2_shape.begin();
+      while (input2_shape.size() < input1_shape.size()) {
+         it = input2_shape.insert(it, 1);
+      }
+   }
+   //check if both the input have same shape, nothing to do directly return the output_shape as the same shape.
+   if(input1_shape.size() == input2_shape.size()){
+      if(input1_shape != input2_shape){
+         //Check the shape values, if input1[i] not equal to input2[i] we have the result shape equal to input1[i] if input2[i] = 1 or viceversa
+         for(size_t j = 0; j < input1_shape.size() ; j++){
+            if(input1_shape[j] == input2_shape[j]){
+               output_shape[j] = input1_shape[j];
+            }
+            else if(input1_shape[j] > input2_shape[j] && input2_shape[j] == 1){
+               output_shape[j] = input1_shape[j];
+            }
+            else if(input2_shape[j] > input1_shape[j] && input1_shape[j] == 1){
+               output_shape[j] = input2_shape[j];
+            }
+         }
+      }
+
+   }
+   return output_shape;
+
+}
+
 std::string UTILITY::Clean_name(std::string input_tensor_name){
    std::string s (input_tensor_name);
    s.erase(std::remove_if(s.begin(), s.end(), []( char const& c ) -> bool { return !std::isalnum(c); } ), s.end());
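
As an illustration of the new broadcasting helper (again, not part of the commit), the sketch below assumes it is called directly on two shape vectors. With shapes {1, 3, 4} and {3, 4}, the shorter shape is padded with a leading 1 and the returned output shape is {1, 3, 4}; all three lengths are 12, so ROperator_BasicBinary::Initialize accepts this pair:

   // Standalone sketch; assumes linking against the SOFIE library.
   #include <cstddef>
   #include <iostream>
   #include <vector>
   #include "TMVA/SOFIE_common.hxx"

   int main() {
      using namespace TMVA::Experimental::SOFIE;

      std::vector<size_t> shapeA = {1, 3, 4};
      std::vector<size_t> shapeB = {3, 4};   // lower rank: padded internally to {1, 3, 4}

      std::vector<size_t> out = UTILITY::Multidirectional_broadcast(shapeA, shapeB);

      for (size_t d : out) std::cout << d << " ";   // prints: 1 3 4
      std::cout << "\n";
      return 0;
   }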
