@@ -13,6 +13,7 @@ namespace SOFIE{
1313 fInputTensorInfos = std::move (other.fInputTensorInfos );
1414 fReadyInputTensorInfos = std::move (other.fReadyInputTensorInfos );
1515 fOutputTensorNames = other.fOutputTensorNames ;
16+ fInputTensorNames = other.fInputTensorNames ;
1617 fOperators = std::move (other.fOperators );
1718 fInitializedTensors = std::move (other.fInitializedTensors );
1819 fIntermediateTensorInfos = std::move (other.fIntermediateTensorInfos );
@@ -28,6 +29,7 @@ namespace SOFIE{
2829 fInputTensorInfos = std::move (other.fInputTensorInfos );
2930 fReadyInputTensorInfos = std::move (other.fReadyInputTensorInfos );
3031 fOutputTensorNames = other.fOutputTensorNames ;
32+ fInputTensorNames = other.fInputTensorNames ;
3133 fOperators = std::move (other.fOperators );
3234 fInitializedTensors = std::move (other.fInitializedTensors );
3335 fIntermediateTensorInfos = std::move (other.fIntermediateTensorInfos );
@@ -83,7 +85,7 @@ namespace SOFIE{
8385 return f4->second .type ;
8486 }
8587
86- throw std::runtime_error (" TMVA SOFIE tensor [" + name + " ] for which the shape is requested is not found" );
88+ throw std::runtime_error (" TMVA SOFIE tensor [" + name + " ] for which the type is requested is not found" );
8789 }
8890
8991 bool RModel::CheckIfTensorAlreadyExist (std::string tensor_name){
@@ -112,6 +114,10 @@ namespace SOFIE{
112114 fReadyInputTensorInfos [input_name] = inputInfo;
113115 }
114116
117+ void RModel::AddInputTensorName (std::string input_name) {
118+ fInputTensorNames .push_back (input_name);
119+ }
120+
115121 void RModel::AddOperator (std::unique_ptr<ROperator> op, int order_execution){
116122 if (order_execution >= 0 ) {
117123 fOperators .insert (fOperators .begin () + order_execution, std::move (op));
@@ -164,15 +170,36 @@ namespace SOFIE{
164170 }
165171 }
166172
167- void RModel::Initialize (){
173+ void RModel::Initialize (int batchSize){
174+ // check if there are only parametrized input tensor and convert in
175+ // ready input tensor according to batch size
176+ // convert parametric shape to a dimensional shape
177+ if (fReadyInputTensorInfos .size () != fInputTensorNames .size ()) {
178+ if ( fReadyInputTensorInfos .size () + fInputTensorInfos .size () != fInputTensorNames .size ())
179+ throw std::runtime_error (" TMVA-SOFIE: RModel::Initializes: invalid inputs" );
180+ for (auto & input : fInputTensorInfos ) {
181+ std::vector<size_t > shape;
182+ shape.reserve (input.second .shape .size ());
183+ for (auto & d : input.second .shape ){
184+ if (d.isParam )
185+ shape.push_back (batchSize);
186+ else
187+ shape.push_back (d.dim );
188+ }
189+ AddInputTensorInfo (input.first , input.second .type , shape);
190+ }
191+ }
192+
193+
168194 for (auto & i : fOperators ){
195+ // std::cout << "initialize operator " << typeid(*i).name() << std::endl;
169196 i->Initialize (*this );
170197 }
171198 }
172199
173- void RModel::Generate (bool useSession, bool useWeightFile){
200+ void RModel::Generate (bool useSession, bool useWeightFile, int batchSize ){
174201 fUseSession = useSession; // session flag is used in operator initialize
175- Initialize ();
202+ Initialize (batchSize );
176203 fGC += (" //Code generated automatically by TMVA for Inference of Model file [" + fFileName + " ] at [" + fParseTime .substr (0 , fParseTime .length ()-1 ) +" ] \n " );
177204 for (auto & i: fNeededStdLib ) {
178205 fGC += " #include<" + i + " >\n " ;
@@ -182,7 +209,7 @@ namespace SOFIE{
182209 fGC += " #include \" TMVA/SOFIE_common.hxx\"\n " ;
183210 if (useWeightFile)
184211 fGC += " #include <fstream>\n " ;
185-
212+
186213 fGC += " \n namespace TMVA_SOFIE_" + fName + " {\n " ;
187214 if (!fNeededBlasRoutines .empty ()) {
188215 fGC += (" namespace BLAS{\n " );
@@ -227,7 +254,7 @@ namespace SOFIE{
227254 fGC += " std::vector<float> fTensor_" + i.first + " = std::vector<float>(" + std::to_string (length) + " );\n " ;
228255 fGC += " float * tensor_" + i.first + " = fTensor_" + i.first + " .data();\n " ;
229256 }
230-
257+
231258 }
232259 }
233260 for (auto &i: fIntermediateTensorInfos ){
@@ -309,16 +336,16 @@ namespace SOFIE{
309336 }
310337 if (outputSize == 1 ) {
311338 size_t outputLength = ConvertShapeToLength (GetTensorShape (fOutputTensorNames [0 ]));
312-
313- fGC += " \t std::vector<float> ret (tensor_" + fOutputTensorNames [0 ] + " , tensor_" + fOutputTensorNames [0 ] + " + " +
339+
340+ fGC += " \t std::vector<float> ret (tensor_" + fOutputTensorNames [0 ] + " , tensor_" + fOutputTensorNames [0 ] + " + " +
314341 std::to_string (outputLength) + " );\n " ;
315342 } else {
316343 for (size_t i = 0 ; i < outputSize; i++) {
317344 if (!fOutputTensorNames [i].empty ()) {
318345 size_t outputLength = ConvertShapeToLength (GetTensorShape (fOutputTensorNames [i]));
319346 fGC += " \t std::vector<float> ret_" ;
320347 fGC += std::to_string (i);
321- fGC += " (tensor_" + fOutputTensorNames [i] + " , tensor_" + fOutputTensorNames [i] + " + " +
348+ fGC += " (tensor_" + fOutputTensorNames [i] + " , tensor_" + fOutputTensorNames [i] + " + " +
322349 std::to_string (outputLength) + " );\n " ;
323350 }
324351 }
@@ -353,7 +380,7 @@ namespace SOFIE{
353380 fGC += " }\n " ;
354381 fGC += " std::string tensor_name;\n " ;
355382 fGC += " int length;\n " ;
356-
383+
357384 // loop on tensors and parse the file
358385 for (auto & i: fInitializedTensors ){
359386 if (i.second .fType == ETensorType::FLOAT){
@@ -370,7 +397,7 @@ namespace SOFIE{
370397 fGC += " throw std::runtime_error(err_msg);\n " ;
371398 fGC += " }\n " ;
372399 fGC += " if (length != " + slength + " ) {\n " ;
373- fGC += " std::string err_msg = \" TMVA-SOFIE failed to read the correct tensor size; expected size is " +
400+ fGC += " std::string err_msg = \" TMVA-SOFIE failed to read the correct tensor size; expected size is " +
374401 slength + " , read \" + std::to_string(length) ;\n " ;
375402 fGC += " throw std::runtime_error(err_msg);\n " ;
376403 fGC += " }\n " ;
@@ -382,7 +409,7 @@ namespace SOFIE{
382409 }
383410
384411 void RModel::WriteInitializedTensorsToFile (std::string filename) {
385- // write the initialized tensors in a text file
412+ // write the initialized tensors in a text file
386413 if (filename == " " ){
387414 filename = fName + " .data" ;
388415 }
0 commit comments