@@ -110,9 +110,6 @@ public:
             // Update the data and the shape of A
             model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
             fShapeA = fShapeY;
-         } else {
-            // Add an intermediate tensor for broadcasting A
-            model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY);
          }
       }
       // Broadcast B to Y
@@ -126,9 +123,6 @@ public:
             // do not update tensor B but add broadcasted one (since it can be input to some other operators)
             model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
             fShapeB = fShapeY;
-         } else {
-            // Add an intermediate tensor for broadcasting B
-            model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY);
          }
       }
       // Broadcast C to Y
@@ -142,9 +136,6 @@ public:
             // do not update tensor C but add broadcasted one (since it can be input to some other operators)
             model.AddConstantTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY, broadcastedData);
             fShapeC = fShapeY;
-         } else {
-            // Add an intermediate tensor for broadcasting B
-            model.AddIntermediateTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY);
          }
       }
    } else {
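
For constant inputs the broadcast is now done once at model-build time, by registering the already-broadcast data with AddConstantTensor, while the else branches that created intermediate broadcast tensors for the runtime case are dropped. As a reference for what that pre-broadcast data contains, here is a minimal sketch of right-aligned, NumPy/ONNX-style broadcasting; it is not SOFIE's actual UTILITY::UnidirectionalBroadcast, and BroadcastToShape is a hypothetical helper name, assuming row-major layout and non-empty shapes:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, only to illustrate the semantics of the pre-broadcast
// constant data: every target element reads the source element obtained by
// dropping the size-1 (broadcast) axes, with shapes right-aligned.
template <typename T>
std::vector<T> BroadcastToShape(const std::vector<T> &data, std::vector<size_t> shape,
                                const std::vector<size_t> &target)
{
   // Left-pad the source shape with 1s so both shapes have the same rank.
   shape.insert(shape.begin(), target.size() - shape.size(), 1);

   // Row-major strides.
   auto strides = [](const std::vector<size_t> &s) {
      std::vector<size_t> st(s.size(), 1);
      for (size_t i = s.size() - 1; i-- > 0;)
         st[i] = st[i + 1] * s[i + 1];
      return st;
   };
   auto srcStride = strides(shape);
   auto dstStride = strides(target);

   size_t length = 1;
   for (auto d : target) length *= d;

   std::vector<T> out(length);
   for (size_t idx = 0; idx < length; ++idx) {
      size_t src = 0;
      for (size_t j = 0; j < target.size(); ++j) {
         size_t idx_j = (idx / dstStride[j]) % target[j];
         if (shape[j] != 1)            // broadcast axes always read element 0
            src += idx_j * srcStride[j];
      }
      out[idx] = data[src];
   }
   return out;
}
```

For example, broadcasting the data {1, 2, 3} of shape {3} to the target shape {2, 3} yields {1, 2, 3, 1, 2, 3}.
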
@@ -252,39 +243,64 @@ public:
       out << SP << "\n//-------- Where " << opName << " --> " << ConvertShapeToString(fShapeY) << "\n";
       size_t length = ConvertShapeToLength(fShapeY);
       std::string typeName = TensorType<T>::Name();
-      // Broadcast A if it's uninitialized
-      if (fShapeA != fShapeY) {
-         out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
-         // out << SP << "{\n";
-         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY)
-             << ", fTensor_" << fNBroadcastedA << ");\n";
+
+      auto stridesA = UTILITY::ComputeStrideFromShape(fShapeA);
+      auto stridesB = UTILITY::ComputeStrideFromShape(fShapeB);
+      auto stridesC = UTILITY::ComputeStrideFromShape(fShapeC);
+      auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
+
+      std::string compute_idx_A, compute_idx_B, compute_idx_C, compute_idx_Y;
+
+      if (std::all_of(fShapeA.begin(), fShapeA.end(), [](size_t x) { return x == 1; })) {
+         compute_idx_A = "0";
+      } else {
+         for (size_t i = 0; i < fShapeA.size(); ++i) {
+            if (fShapeA[i] == 1) continue;
+            compute_idx_A += "idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + std::to_string(stridesA[i]) + " +";
+         }
+         compute_idx_A.pop_back();
       }
-      // Broadcast B if it's uninitialized
-      if (fShapeB != fShapeY) {
-         out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
-         // out << SP << "{\n";
-         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY)
-             << ", fTensor_" << fNBroadcastedB << ");\n";
+      if (std::all_of(fShapeB.begin(), fShapeB.end(), [](size_t x) { return x == 1; })) {
+         compute_idx_B = "0";
+      } else {
+         for (size_t i = 0; i < fShapeB.size(); ++i) {
+            if (fShapeB[i] == 1) continue;
+            compute_idx_B += "idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + std::to_string(stridesB[i]) + " +";
+         }
+         compute_idx_B.pop_back();
       }
-      // Broadcast C if it's uninitialized
-      if (fShapeC != fShapeY) {
-         // special case if C is an input tensor
-         if (fIsInputBoolTensor) {
-            size_t inputLength = ConvertShapeToLength(fShapeC);
-            out << SP << "std::vector<std::uint8_t> fTensor_" << fNC << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n";
+      if (std::all_of(fShapeC.begin(), fShapeC.end(), [](size_t x) { return x == 1; })) {
+         compute_idx_C = "0";
+      } else {
+         for (size_t i = 0; i < fShapeC.size(); ++i) {
+            if (fShapeC[i] == 1) continue;
+            compute_idx_C += "idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeC.size())) + " * " + std::to_string(stridesC[i]) + " +";
          }
-         out << SP << "// Broadcasting uninitialized tensor " << fNC << "\n";
-         // out << SP << "{\n";
-         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<std::uint8_t>(fTensor_" << fNC << ".data(), " << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(fShapeY)
-             << ", fTensor_" << fNBroadcastedC << ");\n";
+         compute_idx_C.pop_back();
+      }
+
+      if (fIsInputBoolTensor) {
+         size_t inputLength = ConvertShapeToLength(fShapeC);
+         out << SP << "std::vector<bool> fTensor_" << fNC << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n";
       }
-      std::string nameA = fNBroadcastedA.empty() ? fNA : fNBroadcastedA;
-      std::string nameB = fNBroadcastedB.empty() ? fNB : fNBroadcastedB;
-      std::string nameC = fNBroadcastedC.empty() ? fNC : fNBroadcastedC;
-      out << SP << "for (size_t id = 0; id < " << length << "; id++){\n";
+
+      for (size_t j = 0; j < fShapeY.size(); ++j) {
+         out << SP << "size_t idx_" << fNY << j << ";\n";
+      }
+      out << SP << "for (size_t idx = 0; idx < " << length << "; ++idx){\n";
+      out << SP << SP << "idx_" << fNY << "0 = idx / " << stridesY[0] << ";\n";
+      compute_idx_Y += "idx_" + fNY + "0 * " + std::to_string(stridesY[0]);
+      std::string modulo_op = "idx % " + std::to_string(stridesY[0]);
+      for (size_t j = 1; j < fShapeY.size(); ++j) {
+         out << SP << SP << "idx_" << fNY << j << " = (" << modulo_op << ") / " << stridesY[j] << ";\n";
+         modulo_op += " % " + std::to_string(stridesY[j]);
+         compute_idx_Y = "idx_" + fNY + std::to_string(j) + " * " + std::to_string(stridesY[j]) + " + " + compute_idx_Y;
+      }
+
       // get output tensor applying condition
-      out << SP << SP << "tensor_" << fNY << "[id] = " << "(fTensor_" << nameC << "[id]) ? tensor_"
-          << nameA << "[id] : tensor_" + nameB + "[id];\n";
+      out << SP << SP << "tensor_" << fNY << "[" << compute_idx_Y << "] = " << "(fTensor_" << fNC << "[" << compute_idx_C << "]) ? tensor_"
+          << fNA << "[" << compute_idx_A << "] : tensor_" + fNB + "[" << compute_idx_B << "];\n";
       out << SP << "}\n";
       return out.str();
    }
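
To see what Generate() now emits, here is a hand-written equivalent of the generated loop for the hypothetical shapes C:{2,1}, A:{1,3}, B:{2,3}, Y:{2,3}; the tensor names and data are made up for illustration. The flat output index idx is decomposed into per-axis indices idx_Y0, idx_Y1 using the output strides {3, 1}, and each input is addressed with only the axes it does not broadcast:

```cpp
#include <cstdio>
#include <vector>

int main()
{
   std::vector<bool>  tensor_C = {true, false};       // condition, shape {2,1}
   std::vector<float> tensor_A = {1.f, 2.f, 3.f};     // shape {1,3}
   std::vector<float> tensor_B = {10.f, 20.f, 30.f,
                                  40.f, 50.f, 60.f};  // shape {2,3}
   std::vector<float> tensor_Y(6);                    // output, shape {2,3}

   // Per-axis indices of the output, recovered from the flat index idx.
   size_t idx_Y0;
   size_t idx_Y1;
   for (size_t idx = 0; idx < 6; ++idx) {
      idx_Y0 = idx / 3;
      idx_Y1 = (idx % 3) / 1;
      tensor_Y[idx_Y1 * 1 + idx_Y0 * 3] =
         (tensor_C[idx_Y0 * 1]) ? tensor_A[idx_Y1 * 1]
                                : tensor_B[idx_Y0 * 3 + idx_Y1 * 1];
   }

   for (float y : tensor_Y)
      std::printf("%g ", y);   // prints: 1 2 3 40 50 60
   std::printf("\n");
   return 0;
}
```

Compared with the previous version, no broadcast copies of A, B, or C are materialised at inference time; only index arithmetic is emitted into the generated code.
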