Skip to content

Commit 70ea47d

Browse files
authored
Support Sub of two variables and Slice in Python ONNX (NeuralNetworkVerification#872)
* expand python onnx parser with slice * update python onnx parser with slice and sub * update readme * Update ONNXParser.py * add constantOfShape * change log
1 parent 4b4c25f commit 70ea47d

File tree

6 files changed

+100
-12
lines changed

6 files changed

+100
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Fixed bug in the parsing of `transpose` nodes in command line C++ parser.
77
- Implemented forward-backward abstract interpretation, symbolic bound tightening, interval arithmetic and simulations for all activation functions.
88
- Added the BaBSR heuristic as a new branching strategy for ReLU Splitting
9+
- Support Sub of two variables, "Mul" of two constants, Slice, and ConstantOfShape in the python onnx parser
910

1011
## Version 2.0.0
1112

maraboupy/parsers/ONNXParser.py

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from onnx import TensorProto
2424
import itertools
2525
from copy import copy
26-
from onnx.reference.ops._op_list import Split_18, Unsqueeze_1
26+
from onnx.reference.ops._op_list import Split_18, Unsqueeze_1, Slice_10
2727

2828
class ONNXParser:
2929
"""
@@ -167,6 +167,8 @@ def makeMarabouEquations(self, nodeName, makeEquations):
167167
self.dropout(node)
168168
elif node.op_type == 'Cast':
169169
self.cast(node)
170+
elif node.op_type == 'ConstantOfShape':
171+
self.constantOfShape(node)
170172
elif node.op_type == 'Reshape':
171173
self.reshape(node)
172174
elif node.op_type == 'Flatten':
@@ -177,6 +179,8 @@ def makeMarabouEquations(self, nodeName, makeEquations):
177179
self.unsqueeze(node)
178180
elif node.op_type == 'Squeeze':
179181
self.squeeze(node)
182+
elif node.op_type == "Slice":
183+
self.slice(node)
180184
elif node.op_type == "BatchNormalization":
181185
self.batchNorm(node, makeEquations)
182186
elif node.op_type == 'Concat':
@@ -282,6 +286,24 @@ def constant(self, node):
282286
return
283287
raise RuntimeError("Could not find value of tensor constant")
284288

289+
def constantOfShape(self, node):
290+
"""Function representing a constant tensor of shape
291+
292+
Args:
293+
node (node): ONNX node representing constantOfShape operation
294+
295+
:meta private:
296+
"""
297+
nodeName = node.output[0]
298+
inputName = node.input[0]
299+
for attr in node.attribute:
300+
if attr.name == "value":
301+
value = numpy_helper.to_array(get_attribute_value(attr))
302+
assert inputName in self.constantMap
303+
shape = self.constantMap[inputName]
304+
self.constantMap[nodeName] = np.broadcast_to(value, shape)
305+
self.shapeMap[nodeName] = shape
306+
285307
def identity(self, node):
286308
"""Function representing identity
287309
@@ -440,6 +462,23 @@ def transpose(self, node):
440462
elif inputName in self.constantMap:
441463
self.constantMap[nodeName] = np.transpose(self.constantMap[inputName], perm)
442464

465+
def slice(self, node):
466+
nodeName = node.output[0]
467+
inputName = node.input[0]
468+
starts = self.constantMap[node.input[1]]
469+
ends = self.constantMap[node.input[2]]
470+
axes = self.constantMap[node.input[3]]
471+
steps = self.constantMap[node.input[4]]
472+
473+
if inputName in self.varMap:
474+
output_data = Slice_10.eval(self.varMap[inputName], starts=starts, ends=ends, axes=axes, steps=steps)
475+
self.shapeMap[nodeName] = output_data.shape
476+
self.varMap[nodeName] = output_data
477+
else:
478+
output_data = Slice_10.eval(self.constantMap[inputName], starts=starts, ends=ends, axes=axes, steps=steps)
479+
self.shapeMap[nodeName] = output_data.shape
480+
self.constantMap[nodeName] = output_data
481+
443482
def unsqueeze(self, node):
444483
"""Function representing unsqueeze
445484
@@ -1055,17 +1094,21 @@ def mulEquations(self, node, makeEquations):
10551094
return
10561095

10571096
multiple = self.constantMap[inputName2]
1058-
input1 = self.varMap[inputName1]
1059-
outputVariables = self.makeNewVariables(nodeName)
1060-
input1 = input1.reshape(-1)
1061-
outputVariables = outputVariables.reshape(-1)
1097+
if inputName1 in self.constantMap:
1098+
input1 = self.constantMap[inputName1]
1099+
self.constantMap[nodeName] = input1 * multiple
1100+
else:
1101+
input1 = self.varMap[inputName1]
1102+
outputVariables = self.makeNewVariables(nodeName)
1103+
input1 = input1.reshape(-1)
1104+
outputVariables = outputVariables.reshape(-1)
10621105

1063-
for i in range(len(input1)):
1064-
e = MarabouUtils.Equation()
1065-
e.addAddend(multiple, input1[i])
1066-
e.addAddend(-1, outputVariables[i])
1067-
e.setScalar(0.0)
1068-
self.query.addEquation(e)
1106+
for i in range(len(input1)):
1107+
e = MarabouUtils.Equation()
1108+
e.addAddend(multiple, input1[i])
1109+
e.addAddend(-1, outputVariables[i])
1110+
e.setScalar(0.0)
1111+
self.query.addEquation(e)
10691112
return
10701113

10711114
def addEquations(self, node, makeEquations):
@@ -1245,7 +1288,22 @@ def subEquations(self, node, makeEquations):
12451288
if not makeEquations:
12461289
return
12471290

1248-
assert inputName1 in self.varMap and inputName2 in self.constantMap
1291+
assert inputName1 in self.varMap and (inputName2 in self.constantMap or inputName2 in self.varMap)
1292+
1293+
# the difference between the two variables
1294+
if inputName1 in self.varMap and inputName2 in self.varMap:
1295+
outputVariables = self.makeNewVariables(nodeName)
1296+
input1 = self.varMap[inputName1].reshape(-1)
1297+
input2 = self.varMap[inputName2].reshape(-1)
1298+
outputVariables = outputVariables.reshape(-1)
1299+
for i in range(len(input1)):
1300+
e = MarabouUtils.Equation()
1301+
e.addAddend(1, input1[i])
1302+
e.addAddend(-1, input2[i])
1303+
e.addAddend(-1, outputVariables[i])
1304+
e.setScalar(0.0)
1305+
self.query.addEquation(e)
1306+
return
12491307

12501308
# Get variables
12511309
inputVars = self.varMap[inputName1].reshape(-1)

maraboupy/test/test_onnx.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ def test_split_onnx():
3939
os.remove(presplit_filename)
4040
os.remove(postsplit_filename)
4141

42+
def test_slice_var():
43+
filename = "slice_test.onnx"
44+
evaluateFile(filename)
45+
46+
def test_slice_constant():
47+
filename = "constant_slice_test.onnx"
48+
evaluateFile(filename)
49+
50+
def test_sub_var():
51+
filename = "sub_var_test.onnx"
52+
evaluateFile(filename)
53+
4254
def test_sub():
4355
filename = "test_sub.onnx"
4456
evaluateFile(filename)
156 Bytes
Binary file not shown.

resources/onnx/slice_test.onnx

887 Bytes
Binary file not shown.

resources/onnx/sub_var_test.onnx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
pytorch2.6.0:
2+
#
3+
input1
4+
input2output/Sub"Sub
5+
main_graphZ
6+
input1
7+

8+

9+
Z
10+
input2
11+

12+

13+
b
14+
output
15+

16+

17+
B

0 commit comments

Comments
 (0)