
Commit a2150e5

Implemented the Multi-directional Broadcasting for SOFIE
1 parent 573f839 commit a2150e5

2 files changed: +54 -1 lines changed

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,8 @@ ETensorType GetTemplatedType(T /*obj*/ ){
 namespace UTILITY{
 template<typename T>
 T* Unidirectional_broadcast(const T* original_data, const std::vector<size_t> original_shape, const std::vector<size_t> target_shape);
+template <typename T>
+std::vector<size_t> Multidirectional_broadcast(const T* original_data, std::vector<size_t> input1_shape, std::vector<size_t> input2_shape);
 std::string Clean_name(std::string input_tensor_name);
 
 
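Note (context, not part of the diff): "multidirectional" broadcasting is the ONNX/NumPy-style rule in which the two shapes are right-aligned, the shorter one is padded with leading 1s, and each pair of dimensions must either match or have one of them equal to 1; for example, shapes {8,1,6,1} and {7,1,5} broadcast to {8,7,6,5}. The new Multidirectional_broadcast declaration returns that common output shape for its two input shapes.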
tmva/sofie/src/SOFIE_common.cxx

Lines changed: 52 additions & 1 deletion
@@ -135,14 +135,65 @@ T* UTILITY::Unidirectional_broadcast(const T* original_data, const std::vector<s
    return new_datavector;
 }
 
+template <typename T>
+std::vector<size_t> UTILITY::Multidirectional_broadcast(const T* original_data, std::vector<size_t> input1_shape, std::vector<size_t> input2_shape)
+{
+   std::vector<size_t> input_shape = (input1_shape > input2_shape) ? input1_shape : input2_shape;
+   std::vector<size_t> output_shape(input_shape);
+
+   // If both inputs already have the same shape there is nothing to do: return output_shape as is.
+   if (input1_shape.size() == input2_shape.size()) {
+      if (input1_shape == input2_shape) {
+         return output_shape;
+      }
+      else {
+         // Compare the dimensions pairwise: if input1_shape[j] != input2_shape[j], the result
+         // dimension is input1_shape[j] when input2_shape[j] == 1, or input2_shape[j] when input1_shape[j] == 1.
+         for (size_t j = 0; j < input1_shape.size(); j++) {
+            if (input1_shape[j] == input2_shape[j]) {
+               output_shape[j] = input1_shape[j];
+            }
+            else if (input1_shape[j] > input2_shape[j] && input2_shape[j] == 1) {
+               output_shape[j] = input1_shape[j];
+            }
+            else if (input2_shape[j] > input1_shape[j] && input1_shape[j] == 1) {
+               output_shape[j] = input2_shape[j];
+            }
+         }
+         return output_shape;
+      }
+   }
+   else if (input1_shape.size() < input2_shape.size()) {
+      // input1 has fewer dimensions: insert 1s at the beginning of input1_shape
+      // until input1_shape.size() == input2_shape.size(), then broadcast the padded shapes.
+      auto it = input1_shape.begin();
+      while (input1_shape.size() < input2_shape.size()) {
+         it = input1_shape.insert(it, 1);
+      }
+      if (input1_shape.size() == input2_shape.size()) {
+         return UTILITY::Multidirectional_broadcast(original_data, input1_shape, input2_shape);
+      }
+   }
+   else if (input2_shape.size() < input1_shape.size()) {
+      // input2 has fewer dimensions: insert 1s at the beginning of input2_shape
+      // until input2_shape.size() == input1_shape.size(), then broadcast the padded shapes.
+      auto it = input2_shape.begin();
+      while (input2_shape.size() < input1_shape.size()) {
+         it = input2_shape.insert(it, 1);
+      }
+      if (input1_shape.size() == input2_shape.size()) {
+         return UTILITY::Multidirectional_broadcast(original_data, input1_shape, input2_shape);
+      }
+   }
+   return output_shape;
+}
+
 std::string UTILITY::Clean_name(std::string input_tensor_name){
    std::string s (input_tensor_name);
    s.erase(std::remove_if(s.begin(), s.end(), []( char const& c ) -> bool { return !std::isalnum(c); } ), s.end());
    return s;
 }
 
 template float* UTILITY::Unidirectional_broadcast(const float* original_data, const std::vector<size_t> original_shape, const std::vector<size_t> target_shape);
+template std::vector<size_t> UTILITY::Multidirectional_broadcast(const float* original_data, std::vector<size_t> input1_shape, std::vector<size_t> input2_shape);
 
 }//SOFIE
 }//Experimental
-}//TMVA
+}//TMVA
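A minimal usage sketch (illustrative only, not part of the commit), assuming the SOFIE library is built with the implementation above; it computes the broadcast output shape for two example input shapes:

#include <cstdio>
#include <vector>
#include "TMVA/SOFIE_common.hxx"

int main() {
   using namespace TMVA::Experimental::SOFIE;
   // Dummy data buffer; only the shapes matter for the output-shape computation.
   std::vector<float> data(8 * 6, 1.f);
   std::vector<size_t> shapeA = {8, 1, 6, 1};
   std::vector<size_t> shapeB = {7, 1, 5};
   // shapeB is padded with leading 1s to {1,7,1,5}; the expected broadcast shape is {8,7,6,5}.
   std::vector<size_t> out = UTILITY::Multidirectional_broadcast(data.data(), shapeA, shapeB);
   for (size_t d : out) std::printf("%zu ", d);
   std::printf("\n");
   return 0;
}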

0 commit comments