|
| 1 | +# Tests MarabouNetwork features not tested by it's subclasses |
| 2 | +import pytest |
| 3 | +from .. import Marabou |
| 4 | +import os |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +# Global settings |
| 8 | +OPT = Marabou.createOptions(verbosity = 0) # Turn off printing |
| 9 | +TOL = 1e-6 # Set tolerance for checking Marabou evaluations |
| 10 | +NETWORK_FOLDER = "../../resources/nnet/" # Folder for test networks |
| 11 | + |
| 12 | +def test_abs_constraint(): |
| 13 | + """ |
| 14 | + Tests the absolute value constraint. |
| 15 | + Based on the acas_1_1 test, with abs constraint added to the outputs. |
| 16 | + """ |
| 17 | + filename = "acasxu/ACASXU_experimental_v2a_1_1.nnet" |
| 18 | + testInputs = [ |
| 19 | + [-0.31182839647533234, 0.0, -0.2387324146378273, -0.5, -0.4166666666666667], |
| 20 | + [-0.16247807039378703, -0.4774648292756546, -0.2387324146378273, -0.3181818181818182, -0.25], |
| 21 | + [-0.2454504737724233, -0.4774648292756546, 0.0, -0.3181818181818182, 0.0] |
| 22 | + ] |
| 23 | + testOutputs = [ |
| 24 | + [abs(0.45556007), 0.44454904, abs(0.49616356), 0.38924966, 0.50136678, abs(testInputs[0][0])], |
| 25 | + [abs(-0.02158248), -0.01885345, abs(-0.01892334), -0.01892597, -0.01893113, abs(testInputs[1][0])], |
| 26 | + [abs(0.05990158), 0.05273383, abs(0.10029709), 0.01883183, 0.10521622, abs(testInputs[2][0])] |
| 27 | + ] |
| 28 | + |
| 29 | + network = loadNetwork(filename) |
| 30 | + |
| 31 | + # Replace two output variables with their's absolute value |
| 32 | + for out in [0, 2]: |
| 33 | + abs_out = network.getNewVariable() |
| 34 | + network.addAbsConstraint(network.outputVars[0][out], abs_out) |
| 35 | + network.outputVars[0][out] = abs_out |
| 36 | + |
| 37 | + abs_inp = network.getNewVariable() |
| 38 | + network.outputVars = np.array([list(network.outputVars[0])+[abs_inp]]) |
| 39 | + network.addAbsConstraint(network.inputVars[0][0], abs_inp) |
| 40 | + |
| 41 | + evaluateNetwork(network, testInputs, testOutputs) |
| 42 | + |
| 43 | +def loadNetwork(filename): |
| 44 | + # Load network relative to this file's location |
| 45 | + filename = os.path.join(os.path.dirname(__file__), NETWORK_FOLDER, filename) |
| 46 | + return Marabou.read_nnet(filename) |
| 47 | + |
| 48 | +def evaluateNetwork(network, testInputs, testOutputs): |
| 49 | + """ |
| 50 | + Load network and evaluate testInputs with Marabou |
| 51 | + """ |
| 52 | + |
| 53 | + for testInput, testOutput in zip(testInputs, testOutputs): |
| 54 | + marabouEval = network.evaluateWithMarabou([testInput], options = OPT, filename = "").flatten() |
| 55 | + |
| 56 | + assert max(abs(marabouEval - testOutput)) < TOL |
| 57 | + return network |
| 58 | + |
0 commit comments