2323from onnx import TensorProto
2424import itertools
2525from copy import copy
26- from onnx .reference .ops ._op_list import Split_18 , Unsqueeze_1
26+ from onnx .reference .ops ._op_list import Split_18 , Unsqueeze_1 , Slice_10
2727
2828class ONNXParser :
2929 """
@@ -167,6 +167,8 @@ def makeMarabouEquations(self, nodeName, makeEquations):
167167 self .dropout (node )
168168 elif node .op_type == 'Cast' :
169169 self .cast (node )
170+ elif node .op_type == 'ConstantOfShape' :
171+ self .constantOfShape (node )
170172 elif node .op_type == 'Reshape' :
171173 self .reshape (node )
172174 elif node .op_type == 'Flatten' :
@@ -177,6 +179,8 @@ def makeMarabouEquations(self, nodeName, makeEquations):
177179 self .unsqueeze (node )
178180 elif node .op_type == 'Squeeze' :
179181 self .squeeze (node )
182+ elif node .op_type == "Slice" :
183+ self .slice (node )
180184 elif node .op_type == "BatchNormalization" :
181185 self .batchNorm (node , makeEquations )
182186 elif node .op_type == 'Concat' :
@@ -282,6 +286,24 @@ def constant(self, node):
282286 return
283287 raise RuntimeError ("Could not find value of tensor constant" )
284288
289+ def constantOfShape (self , node ):
290+ """Function representing a constant tensor of shape
291+
292+ Args:
293+ node (node): ONNX node representing constantOfShape operation
294+
295+ :meta private:
296+ """
297+ nodeName = node .output [0 ]
298+ inputName = node .input [0 ]
299+ for attr in node .attribute :
300+ if attr .name == "value" :
301+ value = numpy_helper .to_array (get_attribute_value (attr ))
302+ assert inputName in self .constantMap
303+ shape = self .constantMap [inputName ]
304+ self .constantMap [nodeName ] = np .broadcast_to (value , shape )
305+ self .shapeMap [nodeName ] = shape
306+
285307 def identity (self , node ):
286308 """Function representing identity
287309
@@ -440,6 +462,23 @@ def transpose(self, node):
440462 elif inputName in self .constantMap :
441463 self .constantMap [nodeName ] = np .transpose (self .constantMap [inputName ], perm )
442464
465+ def slice (self , node ):
466+ nodeName = node .output [0 ]
467+ inputName = node .input [0 ]
468+ starts = self .constantMap [node .input [1 ]]
469+ ends = self .constantMap [node .input [2 ]]
470+ axes = self .constantMap [node .input [3 ]]
471+ steps = self .constantMap [node .input [4 ]]
472+
473+ if inputName in self .varMap :
474+ output_data = Slice_10 .eval (self .varMap [inputName ], starts = starts , ends = ends , axes = axes , steps = steps )
475+ self .shapeMap [nodeName ] = output_data .shape
476+ self .varMap [nodeName ] = output_data
477+ else :
478+ output_data = Slice_10 .eval (self .constantMap [inputName ], starts = starts , ends = ends , axes = axes , steps = steps )
479+ self .shapeMap [nodeName ] = output_data .shape
480+ self .constantMap [nodeName ] = output_data
481+
443482 def unsqueeze (self , node ):
444483 """Function representing unsqueeze
445484
@@ -1055,17 +1094,21 @@ def mulEquations(self, node, makeEquations):
10551094 return
10561095
10571096 multiple = self .constantMap [inputName2 ]
1058- input1 = self .varMap [inputName1 ]
1059- outputVariables = self .makeNewVariables (nodeName )
1060- input1 = input1 .reshape (- 1 )
1061- outputVariables = outputVariables .reshape (- 1 )
1097+ if inputName1 in self .constantMap :
1098+ input1 = self .constantMap [inputName1 ]
1099+ self .constantMap [nodeName ] = input1 * multiple
1100+ else :
1101+ input1 = self .varMap [inputName1 ]
1102+ outputVariables = self .makeNewVariables (nodeName )
1103+ input1 = input1 .reshape (- 1 )
1104+ outputVariables = outputVariables .reshape (- 1 )
10621105
1063- for i in range (len (input1 )):
1064- e = MarabouUtils .Equation ()
1065- e .addAddend (multiple , input1 [i ])
1066- e .addAddend (- 1 , outputVariables [i ])
1067- e .setScalar (0.0 )
1068- self .query .addEquation (e )
1106+ for i in range (len (input1 )):
1107+ e = MarabouUtils .Equation ()
1108+ e .addAddend (multiple , input1 [i ])
1109+ e .addAddend (- 1 , outputVariables [i ])
1110+ e .setScalar (0.0 )
1111+ self .query .addEquation (e )
10691112 return
10701113
10711114 def addEquations (self , node , makeEquations ):
@@ -1245,7 +1288,22 @@ def subEquations(self, node, makeEquations):
12451288 if not makeEquations :
12461289 return
12471290
1248- assert inputName1 in self .varMap and inputName2 in self .constantMap
1291+ assert inputName1 in self .varMap and (inputName2 in self .constantMap or inputName2 in self .varMap )
1292+
1293+ # the difference between the two variables
1294+ if inputName1 in self .varMap and inputName2 in self .varMap :
1295+ outputVariables = self .makeNewVariables (nodeName )
1296+ input1 = self .varMap [inputName1 ].reshape (- 1 )
1297+ input2 = self .varMap [inputName2 ].reshape (- 1 )
1298+ outputVariables = outputVariables .reshape (- 1 )
1299+ for i in range (len (input1 )):
1300+ e = MarabouUtils .Equation ()
1301+ e .addAddend (1 , input1 [i ])
1302+ e .addAddend (- 1 , input2 [i ])
1303+ e .addAddend (- 1 , outputVariables [i ])
1304+ e .setScalar (0.0 )
1305+ self .query .addEquation (e )
1306+ return
12491307
12501308 # Get variables
12511309 inputVars = self .varMap [inputName1 ].reshape (- 1 )
0 commit comments