Skip to content

Commit 4954852

Browse files
committed
Add 4 Binary Operators
1 parent 48f164f commit 4954852

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

tmva/sofie/inc/TMVA/RModel_BasicBinaryOp.hxx

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ namespace Experimental{
1212
namespace SOFIE{
1313

1414
enum EBasicBinaryOperator { Add, Sub, Mul, Div };
15-
template <typename T, EBasicBinaryOperator Op>
16-
struct BinaryOperatorTrait {
17-
const char * Name() { return"";}
18-
const char * Op() { return "";}
19-
};
2015

2116
class RModel_BasicBinaryOp final : public ROperator{
2217
private:
@@ -25,6 +20,11 @@ private:
2520
std::string fNX2;
2621
std::string fNY;
2722
std::vector<size_t> fShape;
23+
template <typename T, EBasicBinaryOperator Op1>
24+
struct BinaryOperatorTrait {
25+
const char * Name() { return "";}
26+
const char * Op() { return "";}
27+
};
2828
template <typename T>
2929
struct BinaryOperatorTrait<T,Add> {
3030
const char * Name() { return "Add";}
@@ -48,8 +48,9 @@ private:
4848
const char * Name() { return "Div";}
4949
const char * Op() { return "/";}
5050
};
51-
template <typename T, EBasicBinaryOperator Op>
52-
BinaryOperatorTrait<T,Op> *s;
51+
52+
template <typename T, EBasicBinaryOperator Op1>
53+
BinaryOperatorTrait<T,Op1> *s;
5354

5455
public:
5556
RModel_BasicBinaryOp(){}
@@ -71,16 +72,16 @@ public:
7172
void Initialize(RModel& model){
7273
// input must be a graph input, or already initialized intermediate tensor
7374
if (model.CheckIfTensorAlreadyExist(fNX1) == false){
74-
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
75+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX1 + "is not found in model");
7576
}
7677
if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
77-
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX2 + "is not found in model");
78+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX2 + "is not found in model");
7879
}
7980
auto shapeX1 = model.GetTensorShape(fNX1);
8081
auto shapeX2 = model.GetTensorShape(fNX2);
8182
// assume same shape X1 and X2
8283
if (shapeX1 != shapeX2) {
83-
std::string msg = "TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
84+
std::string msg = "TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
8485
ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
8586
throw std::runtime_error(msg);
8687
}

0 commit comments

Comments
 (0)