Skip to content

Commit 7871f77

Browse files
committed
Fixing some test cases regarding Multi directional Broadcasting
1 parent 3f05469 commit 7871f77

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ ETensorType GetTemplatedType(T /*obj*/ ){
105105
namespace UTILITY{
106106
template<typename T>
107107
T* Unidirectional_broadcast(const T* original_data, const std::vector<size_t> original_shape, const std::vector<size_t> target_shape);
108-
109-
std::vector<size_t> Multidirectional_broadcast(std::vector<size_t> input1_shape, std::vector<size_t> input2_shape);
108+
std::vector<size_t> Multidirectional_broadcast(const std::vector<size_t> input1_shape, const std::vector<size_t> input2_shape);
110109
std::string Clean_name(std::string input_tensor_name);
111110

112111

tmva/sofie/src/SOFIE_common.cxx

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ T* UTILITY::Unidirectional_broadcast(const T* original_data, const std::vector<s
137137

138138

139139

140-
std::vector<size_t> UTILITY::Multidirectional_broadcast(std::vector<size_t> input1_shape, std::vector<size_t> input2_shape)
140+
std::vector<size_t> UTILITY::Multidirectional_broadcast(std::vector<size_t> input1_shape, std::vector<size_t> input2_shape)
141141
{
142-
std::vector<size_t> input_shape = (input1_shape > input2_shape)?input1_shape:input2_shape;
142+
std::vector<size_t> input_shape = (input1_shape.size() > input2_shape.size())?input1_shape:input2_shape;
143143
std::vector<size_t> output_shape(input_shape);
144144

145145
if(input1_shape.size() < input2_shape.size()){
@@ -148,22 +148,16 @@ std::vector<size_t> UTILITY::Multidirectional_broadcast(std::vector<size_t> inp
148148
while (input1_shape.size() < input2_shape.size()) {
149149
it = input1_shape.insert(it, 1);
150150
}
151-
if(input1_shape.size()==input1_shape.size()){
152-
UTILITY::Multidirectional_broadcast(input1_shape,input2_shape);
153-
}
154151
}
155152
else if(input2_shape.size() < input1_shape.size()){
156153
// Check if input2_shape.size() < input1_shape.size() we insert in the shape vector values of 1 at the beginning of the tensor until input1_shape.size() == input2_shape.size()
157154
auto it = input2_shape.begin();
158155
while (input2_shape.size() < input1_shape.size()) {
159156
it = input2_shape.insert(it, 1);
160157
}
161-
if(input1_shape.size()==input1_shape.size()){
162-
UTILITY::Multidirectional_broadcast(input1_shape,input2_shape);
163-
}
164158
}
165159
//check if both the input have same shape, nothing to do directly return the output_shape as the same shape.
166-
else if(input1_shape.size() == input2_shape.size()){
160+
if(input1_shape.size() == input2_shape.size()){
167161
if(input1_shape != input2_shape){
168162
//Check the shape values, if input1[i] not equal to input2[i] we have the result shape equal to input1[i] if input2[i] = 1 or viceversa
169163
for(size_t j = 0; j < input1_shape.size() ; j++){

0 commit comments

Comments
 (0)