11import logging
22import time
33from time import perf_counter
4- from typing import Dict
4+ from typing import Dict , Union
55
66import numpy as np
77from airobas .verif_pipeline import (
1414from keras .layers import Activation , Dense
1515from keras .models import Sequential , clone_model
1616from maraboupy import Marabou , MarabouCore
17+ from maraboupy .MarabouNetworkONNX import MarabouNetworkONNX
1718from maraboupy .MarabouNetwork import MarabouNetwork # (pip install maraboupy)
18-
19+ from termcolor import colored
1920logger = logging .getLogger (__name__ )
2021
2122output_name = "OUTPUT"
@@ -156,6 +157,7 @@ def buildEquations(self, index_layer, update_relu=True):
156157 if update_relu :
157158 self .add_relu (index_layer )
158159 else :
160+ print (type (layer ))
159161 raise NotImplemented (layer )
160162
161163 def get_output_layer (self , index_layer ):
@@ -229,8 +231,8 @@ def add_relu(self, index_layer):
229231 def get_output_dim (self ):
230232 global output_name
231233 return len (self .varMap [output_name ])
232-
233- def solve_query (self , options = None ):
234+
235+ def solve_query (self ,options = None ):
234236 if options is None :
235237 result = self .solve (verbose = False )
236238 else :
@@ -251,8 +253,33 @@ def solve_query(self, options=None):
251253 output_sat ,
252254 )
253255
256+ def solve_query (network ,options = None ):
257+ if options is None :
258+ result = network .solve (verbose = False )
259+ else :
260+ result = network .solve (verbose = False , options = options )
261+ input_sat = None
262+ output_sat = None
263+
264+ if result [0 ] == "sat" :
265+ n_in = len (network .inputVars [0 ][0 ])
266+ n_out = len (network .outputVars [0 ][0 ])
267+ input_sat = np .array ([result [1 ][network .inputVars [0 ][0 ][i ]] for i in range (n_in )])
268+ output_sat = np .array ([result [1 ][network .outputVars [0 ][0 ][i ]] for i in range (n_out )])
269+ if result [0 ] == "TIMEOUT" :
270+ logger .info (f"Time out !" )
271+ return (
272+ [result [0 ] == "sat" , result [0 ] == "unsat" , result [0 ] == "TIMEOUT" ],
273+ input_sat ,
274+ output_sat ,
275+ )
276+
277+ def solve_stability_property (network : Union [MarabouSequential , MarabouNetworkONNX ], x_min , x_max , y_min , y_max , options = None ):
278+ if isinstance (network ,MarabouSequential ):
279+ output_dim = network .get_output_dim ()
280+ elif isinstance (network ,MarabouNetworkONNX ):
281+ output_dim = len (network .outputVars [0 ][0 ])
254282
255- def solve_stability_property (network : MarabouSequential , x_min , x_max , y_min , y_max , timeout = 0 ):
256283 t_init = time .perf_counter ()
257284 # Set Lower and Upper bound for the input perturbation
258285 for i , x_min_i in enumerate (x_min ):
@@ -262,7 +289,8 @@ def solve_stability_property(network: MarabouSequential, x_min, x_max, y_min, y_
262289 # find a sample that is either greater than Y_max or lower than Y_min
263290 equ_list = []
264291
265- for i in range (network .get_output_dim ()):
292+ for i in range (output_dim ):
293+ #print(f"Old-Encoding\nmax diff inputs bounds: {np.max(x_max-x_min)}\n output lowe {y_min[i]}, output upper {y_max[i]}")
266294 if np .isinf (y_min [i ]) or np .isinf (y_max [i ]):
267295 continue
268296 equ_l = MarabouCore .Equation (MarabouCore .Equation .LE ) # greater or equal >= scalar
@@ -277,19 +305,18 @@ def solve_stability_property(network: MarabouSequential, x_min, x_max, y_min, y_
277305 equ_list .append ([equ_u ]) # one disjunction
278306
279307 network .addDisjunctionConstraint (equ_list )
280- t_end_init = time .perf_counter ()
281- options = None
282- if timeout :
283- options = Marabou .createOptions (timeoutInSeconds = int (timeout ), verbosity = 0 )
284- else :
285- options = Marabou .createOptions (verbosity = 0 )
286- result = network .solve_query (options = options )
308+ if isinstance (network ,MarabouSequential ):
309+ t_end_init = time .perf_counter ()
310+ result = network .solve_query (options )
311+ elif isinstance (network ,MarabouNetworkONNX ):
312+ t_end_init = time .perf_counter ()
313+ result = solve_query (network ,options )
287314 t_end_solve = time .perf_counter ()
288315 network .clearProperty ()
289316 network .disjunctionList = []
317+ #print(f'marabou solve: {result[0]}')
290318 return result , (t_init , t_end_init , t_end_solve )
291319
292-
293320class MarabouBlock (BlockVerif ):
294321 def __init__ (
295322 self ,
@@ -298,7 +325,22 @@ def __init__(
298325 ** kwargs ,
299326 ):
300327 super ().__init__ (problem_container = problem_container , data_container = data_container )
301- self .options = kwargs
328+ # Initialize self.marabou_NetworkONNX
329+ if 'marabou_ONNX' in kwargs :
330+ self .marabou_NetworkONNX = kwargs ['marabou_ONNX' ]
331+ kwargs .pop ('marabou_ONNX' )
332+ else :
333+ self .marabou_NetworkONNX = None
334+ # Initialize self.options by passing the collected kwargs to Marabou.createOption
335+ self .options = Marabou .createOptions (** kwargs )
336+
337+
338+ def display_options (self ):
339+ """Helper method to display current Marabou options."""
340+ print ("\n --- Current Marabou Options ---" )
341+ for key , value in self .options .items ():
342+ print (f" { key } : { value } " )
343+ print ("-------------------------------" )
302344
303345 def verif (self , indexes : np .ndarray ) -> BlockVerifOutput :
304346 nb_points = len (indexes )
@@ -310,26 +352,32 @@ def verif(self, indexes: np.ndarray) -> BlockVerifOutput:
310352 init_time_per_sample = np .empty (nb_points , dtype = float ),
311353 verif_time_per_sample = np .empty (nb_points , dtype = float ),
312354 )
313- t1 = perf_counter ()
314- network = MarabouSequential (model = self .problem_container .model )
315- t2 = perf_counter ()
316- output .build_time = t2 - t1
355+ if self .marabou_NetworkONNX is None :
356+ t1 = perf_counter ()
357+ network = MarabouSequential (model = self .problem_container .model )
358+ t2 = perf_counter ()
359+ output .build_time = t2 - t1
360+ else :
361+ network = self .marabou_NetworkONNX
362+ output .build_time = 0
317363 x_min = self .data_container .lbound_input_points [indexes , :]
318364 x_max = self .data_container .ubound_input_points [indexes , :]
319365 y_min = self .data_container .lbound_output_points [indexes , :]
320366 y_max = self .data_container .ubound_output_points [indexes , :]
321367 for index in range (nb_points ):
322- ((score , input_sat , output_sat ), times ) = solve_stability_property (
368+ ((score , input_sat , output_sat ), times ) = solve_stability_property (
323369 network ,
324370 x_min = x_min [index ],
325371 x_max = x_max [index ],
326372 y_min = y_min [index ],
327373 y_max = y_max [index ],
374+ options = self .options ,
328375 timeout = self .options .get ("time_out" , 200 ),
329376 )
330377 output .init_time_per_sample [index ] = times [1 ] - times [0 ]
331378 output .verif_time_per_sample [index ] = times [2 ] - times [1 ]
332379 status = StatusVerif .UNKNOWN
380+
333381 if score [0 ]:
334382 # Found counter example
335383 status = StatusVerif .VIOLATED
@@ -343,6 +391,11 @@ def verif(self, indexes: np.ndarray) -> BlockVerifOutput:
343391 logger .info (f"Current Verified (%) { np .sum (output .status == StatusVerif .VERIFIED ) / nb_points * 100 } " )
344392 logger .info (f"Current Violated (%) { np .sum (output .status == StatusVerif .VIOLATED ) / nb_points * 100 } " )
345393 logger .info (f"Current Timeout (%) { np .sum (output .status == StatusVerif .TIMEOUT ) / nb_points * 100 } " )
394+ # times returned by solve_query_property = (t_init, t_end_init, t_end_solve)
395+ print (colored (f"\n \n \
396+ Time to build marabou model: { output .build_time } \n \
397+ Time to init marabou model: { output .init_time_per_sample [index ]} \n \
398+ Time to verify property (marabou): { output .verif_time_per_sample [index ]} , " ,'red' ))
346399 return output
347400
348401 @staticmethod
0 commit comments