Skip to content

Commit 44c20d2

Browse files
committed
Required Support and Test for Multibroadcasting added.
1 parent 7871f77 commit 44c20d2

File tree

4 files changed

+64
-4
lines changed

4 files changed

+64
-4
lines changed

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,17 @@ public:
8383
auto shapeX2 = model.GetTensorShape(fNX2);
8484
// assume same shape X1 and X2
8585
if (shapeX1 != shapeX2) {
86-
std::string msg = "TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
87-
ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
88-
throw std::runtime_error(msg);
86+
fShape = UTILITY::Multidirectional_broadcast(shapeX1,shapeX2);
87+
size_t length1 = ConvertShapeToLength(shapeX1);
88+
size_t length2 = ConvertShapeToLength(shapeX2);
89+
size_t output_length = ConvertShapeToLength(fShape);
90+
if(length1 != length2 || length1 != output_length){
91+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op does not support input tensors with different lengths. The output tensor should also have the same length as the input tensors."));
92+
}
8993
}
90-
fShape = shapeX1;
94+
else if(shapeX1 == shapeX2){
95+
fShape = shapeX1;
96+
}
9197
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX1), fShape);
9298
}
9399

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
#include "Add_FromONNX.hxx"
1919
#include "input_models/references/Add.ref.hxx"
2020

21+
#include "Add_broadcast_FromONNX.hxx"
22+
#include "input_models/references/Add_broadcast.ref.hxx"
23+
2124
#include "Mul_FromONNX.hxx"
2225
#include "input_models/references/Mul.ref.hxx"
2326

@@ -172,6 +175,33 @@ TEST(ONNX, Sub)
172175
}
173176
}
174177

178+
TEST(ONNX, Add_broadcast)
179+
{
180+
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
181+
182+
// Preparing the standard input
183+
std::vector<float> input1({
184+
1, 2, 3,
185+
3, 4, 5
186+
});
187+
std::vector<float> input2({
188+
5, 6, 7,
189+
8, 9, 10
190+
});
191+
TMVA_SOFIE_Add_broadcast::Session s("Add_broadcast_FromONNX.dat");
192+
193+
std::vector<float> output = s.infer(input2.data(),input1.data());
194+
195+
// Checking output size
196+
EXPECT_EQ(output.size(), sizeof(Add_broadcast_ExpectedOutput::outputs) / sizeof(float));
197+
198+
float *correct = Add_broadcast_ExpectedOutput::outputs;
199+
200+
// Checking every output value, one by one
201+
for (size_t i = 0; i < output.size(); ++i) {
202+
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
203+
}
204+
}
175205

176206
TEST(ONNX, Add)
177207
{
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
pytorch1.11.0:�
2+
)
3+
onnx::Add_0
4+
onnx::Add_12Add_0"Addtorch-jit-exportZ!
5+
onnx::Add_0
6+

7+

8+

9+
Z
10+
onnx::Add_1
11+

12+

13+
b
14+
2
15+

16+

17+

18+
B
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
namespace Add_broadcast_ExpectedOutput{
2+
float outputs[] = {
3+
6, 8, 10,
4+
11, 13, 15
5+
};
6+
} // namespace Add_ExpectedOutput

0 commit comments

Comments
 (0)