Skip to content

Commit adbe073

Browse files
vbCrLfGuyKatzHujiguykatzz
authored
Add abs constraint support to maraboupy (#327)
* Add abs constraint support to maraboupy * Add test for abs constraint * Fix test * Fix test * remove a too-strong assertion * attempt to fix the windows bug * oops * assertion Co-authored-by: Guy Katz <guykatz@cs.huji.ac.il> Co-authored-by: Guy <katz911@gmail.com>
1 parent 8b42f74 commit adbe073

File tree

6 files changed

+110
-16
lines changed

6 files changed

+110
-16
lines changed

maraboupy/MarabouCore.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ void addMaxConstraint(InputQuery& ipq, std::set<unsigned> elements, unsigned v){
9696
ipq.addPiecewiseLinearConstraint(m);
9797
}
9898

99+
void addAbsConstraint(InputQuery& ipq, unsigned b, unsigned f){
100+
ipq.addPiecewiseLinearConstraint(new AbsoluteValueConstraint(b, f));
101+
}
102+
99103
void createInputQuery(InputQuery &inputQuery, std::string networkFilePath, std::string propertyFilePath){
100104
AcasParser* acasParser = new AcasParser( String(networkFilePath) );
101105
acasParser->generateQuery( inputQuery );
@@ -265,6 +269,15 @@ PYBIND11_MODULE(MarabouCore, m) {
265269
v (int): Output variable from max constraint
266270
)pbdoc",
267271
py::arg("inputQuery"), py::arg("elements"), py::arg("v"));
272+
m.def("addAbsConstraint", &addAbsConstraint, R"pbdoc(
273+
Add an Abs constraint to the InputQuery
274+
275+
Args:
276+
inputQuery (:class:`~maraboupy.MarabouCore.InputQuery`): Marabou input query to be solved
277+
b (int): Input variable
278+
f (int): Output variable
279+
)pbdoc",
280+
py::arg("inputQuery"), py::arg("b"), py::arg("f"));
268281
py::class_<InputQuery>(m, "InputQuery")
269282
.def(py::init())
270283
.def("setUpperBound", &InputQuery::setUpperBound)

maraboupy/MarabouNetwork.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class MarabouNetwork:
2626
equList (list of :class:`~maraboupy.MarabouUtils.Equation`): Network equations
2727
reluList (list of tuples): List of relu constraint tuples, where each tuple contains the backward and forward variables
2828
maxList (list of tuples): List of max constraint tuples, where each tuple conatins the set of input variables and output variable
29+
absList (list of tuples): List of abs constraint tuples, where each tuple conatins the input variable and the output variable
2930
varsParticipatingInConstraints (set of int): Variables involved in some constraint
3031
lowerBounds (Dict[int, float]): Lower bounds of variables
3132
upperBounds (Dict[int, float]): Upper bounds of variables
@@ -44,6 +45,7 @@ def clear(self):
4445
self.equList = []
4546
self.reluList = []
4647
self.maxList = []
48+
self.absList = []
4749
self.varsParticipatingInConstraints = set()
4850
self.lowerBounds = dict()
4951
self.upperBounds = dict()
@@ -110,6 +112,17 @@ def addMaxConstraint(self, elements, v):
110112
for i in elements:
111113
self.varsParticipatingInConstraints.add(i)
112114

115+
def addAbsConstraint(self, b, f):
116+
"""Function to add a new Abs constraint
117+
118+
Args:
119+
b (int): Variable representing input of the Abs constraint
120+
f (int): Variable representing output of the Abs constraint
121+
"""
122+
self.absList += [(b, f)]
123+
self.varsParticipatingInConstraints.add(b)
124+
self.varsParticipatingInConstraints.add(f)
125+
113126
def lowerBoundExists(self, x):
114127
"""Function to check whether lower bound for a variable is known
115128
@@ -209,6 +222,9 @@ def getMarabouQuery(self):
209222
assert e < self.numVars
210223
MarabouCore.addMaxConstraint(ipq, m[0], m[1])
211224

225+
for b, f in self.absList:
226+
MarabouCore.addAbsConstraint(ipq, b, f)
227+
212228
for l in self.lowerBounds:
213229
assert l < self.numVars
214230
ipq.setLowerBound(l, self.lowerBounds[l])

maraboupy/test/test_network.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Tests MarabouNetwork features not tested by it's subclasses
2+
import pytest
3+
from .. import Marabou
4+
import os
5+
import numpy as np
6+
7+
# Global settings
8+
OPT = Marabou.createOptions(verbosity = 0) # Turn off printing
9+
TOL = 1e-6 # Set tolerance for checking Marabou evaluations
10+
NETWORK_FOLDER = "../../resources/nnet/" # Folder for test networks
11+
12+
def test_abs_constraint():
13+
"""
14+
Tests the absolute value constraint.
15+
Based on the acas_1_1 test, with abs constraint added to the outputs.
16+
"""
17+
filename = "acasxu/ACASXU_experimental_v2a_1_1.nnet"
18+
testInputs = [
19+
[-0.31182839647533234, 0.0, -0.2387324146378273, -0.5, -0.4166666666666667],
20+
[-0.16247807039378703, -0.4774648292756546, -0.2387324146378273, -0.3181818181818182, -0.25],
21+
[-0.2454504737724233, -0.4774648292756546, 0.0, -0.3181818181818182, 0.0]
22+
]
23+
testOutputs = [
24+
[abs(0.45556007), 0.44454904, abs(0.49616356), 0.38924966, 0.50136678, abs(testInputs[0][0])],
25+
[abs(-0.02158248), -0.01885345, abs(-0.01892334), -0.01892597, -0.01893113, abs(testInputs[1][0])],
26+
[abs(0.05990158), 0.05273383, abs(0.10029709), 0.01883183, 0.10521622, abs(testInputs[2][0])]
27+
]
28+
29+
network = loadNetwork(filename)
30+
31+
# Replace two output variables with their's absolute value
32+
for out in [0, 2]:
33+
abs_out = network.getNewVariable()
34+
network.addAbsConstraint(network.outputVars[0][out], abs_out)
35+
network.outputVars[0][out] = abs_out
36+
37+
abs_inp = network.getNewVariable()
38+
network.outputVars = np.array([list(network.outputVars[0])+[abs_inp]])
39+
network.addAbsConstraint(network.inputVars[0][0], abs_inp)
40+
41+
evaluateNetwork(network, testInputs, testOutputs)
42+
43+
def loadNetwork(filename):
44+
# Load network relative to this file's location
45+
filename = os.path.join(os.path.dirname(__file__), NETWORK_FOLDER, filename)
46+
return Marabou.read_nnet(filename)
47+
48+
def evaluateNetwork(network, testInputs, testOutputs):
49+
"""
50+
Load network and evaluate testInputs with Marabou
51+
"""
52+
53+
for testInput, testOutput in zip(testInputs, testOutputs):
54+
marabouEval = network.evaluateWithMarabou([testInput], options = OPT, filename = "").flatten()
55+
56+
assert max(abs(marabouEval - testOutput)) < TOL
57+
return network
58+

src/basis_factorization/CSRMatrix.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ void CSRMatrix::initializeToEmpty( unsigned m, unsigned n )
104104

105105
void CSRMatrix::increaseCapacity()
106106
{
107-
ASSERT( _m > 0 && _n > 0 );
108-
109107
unsigned estimatedNumRowEntries = std::max( 2U, _n / ROW_DENSITY_ESTIMATE );
110108
unsigned newEstimatedNnz = _estimatedNnz + ( estimatedNumRowEntries * _m );
111109

src/engine/AbsoluteValueConstraint.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ PiecewiseLinearCaseSplit AbsoluteValueConstraint::getValidCaseSplit() const
267267
void AbsoluteValueConstraint::eliminateVariable( unsigned variable, double /* fixedValue */ )
268268
{
269269
(void)variable;
270-
ASSERT( variable = _b );
270+
ASSERT( ( variable == _f ) || ( variable == _b ) );
271271

272272
// In an absolute value constraint, if a variable is removed the
273273
// entire constraint can be discarded

src/engine/Preprocessor.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,9 @@ void Preprocessor::collectFixedValues()
529529
usedVariables.insert( merged.first );
530530

531531
// Collect any variables with identical lower and upper bounds, or
532-
// which are unused, unless they are input/output variables
532+
// which are unused
533533
for ( unsigned i = 0; i < _preprocessed.getNumberOfVariables(); ++i )
534534
{
535-
if ( _inputOutputVariables.exists( i ) )
536-
continue;
537-
538535
if ( FloatUtils::areEqual( _preprocessed.getLowerBound( i ), _preprocessed.getUpperBound( i ) ) )
539536
{
540537
_fixedVariables[i] = _preprocessed.getLowerBound( i );
@@ -613,19 +610,30 @@ void Preprocessor::eliminateVariables()
613610
}
614611
}
615612

616-
// Inform the NLR about eliminated varibales
617-
if ( _preprocessed._networkLevelReasoner )
618-
{
619-
for ( const auto &fixed : _fixedVariables )
620-
_preprocessed._networkLevelReasoner->eliminateVariable( fixed.first, fixed.second );
621-
}
613+
// Inform the NLR about eliminated varibales, unless they are
614+
// input/output variables
615+
if ( _preprocessed._networkLevelReasoner )
616+
{
617+
for ( const auto &fixed : _fixedVariables )
618+
{
619+
if ( _inputOutputVariables.exists( fixed.first ) )
620+
continue;
621+
622+
_preprocessed._networkLevelReasoner->eliminateVariable( fixed.first, fixed.second );
623+
}
624+
}
622625

623626
// Compute the new variable indices, after the elimination of fixed variables
624627
int offset = 0;
628+
unsigned numEliminated = 0;
625629
for ( unsigned i = 0; i < _preprocessed.getNumberOfVariables(); ++i )
626630
{
627-
if ( _fixedVariables.exists( i ) || _mergedVariables.exists( i ) )
631+
if ( ( _fixedVariables.exists( i ) || _mergedVariables.exists( i ) ) &&
632+
!_inputOutputVariables.exists( i ) )
633+
{
634+
++numEliminated;
628635
++offset;
636+
}
629637
else
630638
_oldIndexToNewIndex[i] = i - offset;
631639
}
@@ -720,7 +728,8 @@ void Preprocessor::eliminateVariables()
720728
// Update the lower/upper bound maps
721729
for ( unsigned i = 0; i < _preprocessed.getNumberOfVariables(); ++i )
722730
{
723-
if ( _fixedVariables.exists( i ) || _mergedVariables.exists( i ) )
731+
if ( ( _fixedVariables.exists( i ) || _mergedVariables.exists( i ) ) &&
732+
!_inputOutputVariables.exists( i ) )
724733
continue;
725734

726735
ASSERT( _oldIndexToNewIndex.at( i ) <= i );
@@ -748,7 +757,7 @@ void Preprocessor::eliminateVariables()
748757
}
749758

750759
// Adjust the number of variables in the query
751-
_preprocessed.setNumberOfVariables( _preprocessed.getNumberOfVariables() - _fixedVariables.size() - _mergedVariables.size() );
760+
_preprocessed.setNumberOfVariables( _preprocessed.getNumberOfVariables() - numEliminated );
752761

753762
// Adjust the input/output mappings in the query
754763
_preprocessed.adjustInputOutputMapping( _oldIndexToNewIndex, _mergedVariables );

0 commit comments

Comments
 (0)