@@ -274,7 +274,7 @@ def solve_query(network,options=None):
274274 output_sat ,
275275 )
276276
277- def solve_stability_property (network : Union [MarabouSequential , MarabouNetworkONNX ], x_min , x_max , y_min , y_max , options = None ):
277+ def solve_stability_property_deprected (network : Union [MarabouSequential , MarabouNetworkONNX ], x_min , x_max , y_min , y_max , options = None ):
278278 if isinstance (network ,MarabouSequential ):
279279 output_dim = network .get_output_dim ()
280280 elif isinstance (network ,MarabouNetworkONNX ):
@@ -317,6 +317,40 @@ def solve_stability_property(network: Union[MarabouSequential, MarabouNetworkONN
317317 #print(f'marabou solve: {result[0]}')
318318 return result , (t_init , t_end_init , t_end_solve )
319319
320+ def solve_stability_property (network : Union [MarabouSequential , MarabouNetworkONNX ], x_min , x_max , y_min , y_max , options = None ,logits_rank = None ):
321+
322+ t_init = time .perf_counter ()
323+ # Set Lower and Upper bound for the input perturbation
324+ for i , x_min_i in enumerate (x_min ):
325+ network .setLowerBound (network .inputVars [0 ][0 ][i ], x_min_i )
326+ for i , x_max_i in enumerate (x_max ):
327+ network .setUpperBound (network .inputVars [0 ][0 ][i ], x_max_i )
328+
329+ # find a sample that is either greater than Y_max or lower than Y_min
330+ for (coeff , bound ) in zip ([1 ,- 1 ],[y_min ,y_max ]):
331+ order_bounds = np .argsort (logits_rank )[::- 1 ]
332+ for i in order_bounds :
333+ if np .abs (bound [i ])>= 1e6 :
334+ continue
335+ # equ_l : f(x)[i]< Y_min[i] or f(x)[i]> Y_max[i]
336+ network .addInequality ([network .outputVars [0 ][0 ][i ]],\
337+ [coeff ],
338+ coeff * bound [i ],
339+ isProperty = True )
340+ if isinstance (network ,MarabouSequential ):
341+ t_end_init = time .perf_counter () # to verify
342+ result = network .solve_query (options )
343+ elif isinstance (network ,MarabouNetworkONNX ):
344+ t_end_init = time .perf_counter ()
345+ result = solve_query (network ,options )
346+ t_end_solve = time .perf_counter ()
347+ network .additionalEquList .clear ()
348+
349+ exit_code = result [0 ] # solve_query return: [result[0] == "sat", result[0] == "unsat", result[0] == "TIMEOUT"],
350+ if exit_code [0 ] or exit_code [- 1 ]:
351+ break
352+ network .clearProperty ()
353+ return result , (t_init , t_end_init , t_end_solve )
320354class MarabouBlock (BlockVerif ):
321355 def __init__ (
322356 self ,
@@ -365,14 +399,14 @@ def verif(self, indexes: np.ndarray) -> BlockVerifOutput:
365399 y_min = self .data_container .lbound_output_points [indexes , :]
366400 y_max = self .data_container .ubound_output_points [indexes , :]
367401 for index in range (nb_points ):
368- import pdb ; pdb .set_trace ()
369402 ((score , input_sat , output_sat ), times ) = solve_stability_property (
370403 network ,
371404 x_min = x_min [index ],
372405 x_max = x_max [index ],
373406 y_min = y_min [index ],
374407 y_max = y_max [index ],
375408 options = self .options ,
409+ logits_rank = self .data_container .output_points [index ],
376410 #timeout=self.options.get("time_out", 200),
377411 )
378412 output .init_time_per_sample [index ] = times [1 ] - times [0 ]
0 commit comments