@@ -25,7 +25,7 @@ std::unique_ptr<ROperator> make_ROperator(size_t idx, const onnx::GraphProto& gr
2525}
2626// enum EBasicBinaryOperator { Add, Sub, Mul, Div };
2727template <EBasicBinaryOperator Op1>
28- std::unique_ptr<ROperator> make_ROperator_BasicBinary (const onnx::NodeProto& nodeproto, const onnx::GraphProto& /* graphproto */ , std::unordered_map<std::string, ETensorType>& tensor_type){
28+ std::unique_ptr<ROperator> make_ROperator_BasicBinary (const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type){
2929
3030 ETensorType input_type = ETensorType::UNDEFINED;
3131
@@ -38,7 +38,16 @@ std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto& nod
3838 else
3939 assert (it->second == input_type);
4040 } else {
41- throw std::runtime_error (" TMVA::SOFIE ONNX Parser Binary op has input tensor" + input_name + " but its type is not yet registered" );
41+ // check if input tensor is an initialized tensor
42+ bool isInitializer = false ;
43+ for (int i=0 ; i < graphproto.initializer_size (); i++){
44+ if (input_name == graphproto.initializer (i).name ()) {
45+ isInitializer = true ;
46+ break ;
47+ }
48+ }
49+ if (!isInitializer)
50+ throw std::runtime_error (" TMVA::SOFIE ONNX Parser Binary op has input tensor " + input_name + " but its type is not yet registered" );
4251 }
4352 }
4453
@@ -151,9 +160,9 @@ std::unique_ptr<ROperator> make_ROperator_LeakyRelu(const onnx::NodeProto& nodep
151160
152161 for (int_t i = 0 ; i < nodeproto.attribute_size (); i++) {
153162 std::string attribute_name = nodeproto.attribute (i).name ();
154- if (attribute_name == " alpha" )
163+ if (attribute_name == " alpha" )
155164 attr_alpha = nodeproto.attribute (i).f ();
156- }
165+ }
157166 switch (input_type){
158167 case ETensorType::FLOAT:
159168 op.reset (new ROperator_LeakyRelu<float >(attr_alpha,nodeproto.input (0 ), nodeproto.output (0 )));
@@ -476,7 +485,7 @@ std::unique_ptr<ROperator> make_ROperator_Pool(const onnx::NodeProto& nodeproto,
476485 RAttributes_Pool attr;
477486 // std::string attr_auto_pad = "NOTSET";
478487 // int attr_ceil_mode = 0;
479- // int attr_count_include_pad = 0;
488+ // int attr_count_include_pad = 0;
480489 // int attr_storage_order = 0; // not for AveragePool
481490 // std::vector<size_t> attr_dilations; // not for AveragePool
482491 // std::vector<size_t> attr_kernel_shape;
@@ -532,22 +541,22 @@ std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodepro
532541 // make Reshape operator
533542 ETensorType input_type = ETensorType::UNDEFINED;
534543
535-
544+
536545 ReshapeOpMode opMode = Reshape;
537- if (nodeproto.op_type () == " Flatten" )
546+ if (nodeproto.op_type () == " Flatten" )
538547 opMode = Flatten;
539- else if (nodeproto.op_type () == " Squeeze" )
548+ else if (nodeproto.op_type () == " Squeeze" )
540549 opMode = Squeeze;
541550 else if (nodeproto.op_type () == " Unsqueeze" )
542551 opMode = Unsqueeze;
543552
544-
553+
545554 // bool hasShapeInput = (opMode == Reshape) ? true : false;
546555
547- // reshape has as extra input shape tensor (int64) but
556+ // reshape has as extra input shape tensor (int64) but
548557 // it is not present for Flatten, Squeeze and Unsquueze
549558 auto input_name = nodeproto.input (0 );
550- // for squeeze is optional ?
559+ // for squeeze is optional ?
551560 auto shape_name = (opMode == Reshape || opMode == Unsqueeze) ? nodeproto.input (1 ) : " " ;
552561 auto it = tensor_type.find (input_name);
553562 if (it != tensor_type.end ()) {
@@ -561,7 +570,7 @@ std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodepro
561570 // Flatten is having one attribute: axis (int) (default=1)
562571 // old version of reshape and squeeze have axes as attributes
563572 std::unique_ptr<ROperator> op;
564- int attr_value = (opMode == Reshape) ? 0 : 1 ;
573+ int attr_value = (opMode == Reshape) ? 0 : 1 ;
565574 if (opMode == Reshape && nodeproto.attribute_size () > 0 )
566575 attr_value = nodeproto.attribute (0 ).i ();
567576
0 commit comments