Skip to content

Commit 573f839

Browse files
lmonetaNeel-Shah-29
authored andcommitted
Fix issue in parsing binary operators when one input is an initialized tensor
In Add,Sub, Mul or Div one of the input can be an initialized tensor therefore we don;t have its input type registered before parsing. We need to look if the tensor is in Initilizer tensor list
1 parent 9c9dd83 commit 573f839

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

tmva/sofie_parsers/src/RModelParser_ONNX.cxx

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ std::unique_ptr<ROperator> make_ROperator(size_t idx, const onnx::GraphProto& gr
2525
}
2626
// enum EBasicBinaryOperator { Add, Sub, Mul, Div };
2727
template<EBasicBinaryOperator Op1>
28-
std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
28+
std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type){
2929

3030
ETensorType input_type = ETensorType::UNDEFINED;
3131

@@ -38,7 +38,16 @@ std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto& nod
3838
else
3939
assert(it->second == input_type);
4040
} else {
41-
throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor" + input_name + " but its type is not yet registered");
41+
// check if input tensor is an initialized tensor
42+
bool isInitializer = false;
43+
for (int i=0; i < graphproto.initializer_size(); i++){
44+
if (input_name == graphproto.initializer(i).name()) {
45+
isInitializer = true;
46+
break;
47+
}
48+
}
49+
if (!isInitializer)
50+
throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor " + input_name + " but its type is not yet registered");
4251
}
4352
}
4453

@@ -151,9 +160,9 @@ std::unique_ptr<ROperator> make_ROperator_LeakyRelu(const onnx::NodeProto& nodep
151160

152161
for (int_t i = 0; i < nodeproto.attribute_size(); i++) {
153162
std::string attribute_name = nodeproto.attribute(i).name();
154-
if (attribute_name == "alpha")
163+
if (attribute_name == "alpha")
155164
attr_alpha = nodeproto.attribute(i).f();
156-
}
165+
}
157166
switch(input_type){
158167
case ETensorType::FLOAT:
159168
op.reset(new ROperator_LeakyRelu<float>(attr_alpha,nodeproto.input(0), nodeproto.output(0)));
@@ -476,7 +485,7 @@ std::unique_ptr<ROperator> make_ROperator_Pool(const onnx::NodeProto& nodeproto,
476485
RAttributes_Pool attr;
477486
// std::string attr_auto_pad = "NOTSET";
478487
// int attr_ceil_mode = 0;
479-
// int attr_count_include_pad = 0;
488+
// int attr_count_include_pad = 0;
480489
// int attr_storage_order = 0; // not for AveragePool
481490
// std::vector<size_t> attr_dilations; // not for AveragePool
482491
// std::vector<size_t> attr_kernel_shape;
@@ -532,22 +541,22 @@ std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodepro
532541
// make Reshape operator
533542
ETensorType input_type = ETensorType::UNDEFINED;
534543

535-
544+
536545
ReshapeOpMode opMode = Reshape;
537-
if (nodeproto.op_type() == "Flatten")
546+
if (nodeproto.op_type() == "Flatten")
538547
opMode = Flatten;
539-
else if (nodeproto.op_type() == "Squeeze")
548+
else if (nodeproto.op_type() == "Squeeze")
540549
opMode = Squeeze;
541550
else if (nodeproto.op_type() == "Unsqueeze")
542551
opMode = Unsqueeze;
543552

544-
553+
545554
//bool hasShapeInput = (opMode == Reshape) ? true : false;
546555

547-
// reshape has as extra input shape tensor (int64) but
556+
// reshape has as extra input shape tensor (int64) but
548557
// it is not present for Flatten, Squeeze and Unsquueze
549558
auto input_name = nodeproto.input(0);
550-
// for squeeze is optional ?
559+
// for squeeze is optional ?
551560
auto shape_name = (opMode == Reshape || opMode == Unsqueeze) ? nodeproto.input(1) : "";
552561
auto it = tensor_type.find(input_name);
553562
if (it != tensor_type.end()) {
@@ -561,7 +570,7 @@ std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodepro
561570
// Flatten is having one attribute: axis (int) (default=1)
562571
// old version of reshape and squeeze have axes as attributes
563572
std::unique_ptr<ROperator> op;
564-
int attr_value = (opMode == Reshape) ? 0 : 1;
573+
int attr_value = (opMode == Reshape) ? 0 : 1;
565574
if (opMode == Reshape && nodeproto.attribute_size() > 0 )
566575
attr_value = nodeproto.attribute(0).i();
567576

0 commit comments

Comments
 (0)