diff --git a/Cargo.lock b/Cargo.lock index 150dcc59..1229f7f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2676,6 +2676,7 @@ dependencies = [ "num", "once_cell", "ref-cast", + "regex", "tempfile", "thiserror 1.0.69", "tracing", diff --git a/src/driver.rs b/src/driver.rs index 146ef65d..d075bef8 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -53,7 +53,7 @@ use crate::{ vcgen::Vcgen, }, version::write_detailed_version_info, - DebugOptions, SliceOptions, SliceVerifyMethod, VerifyCommand, VerifyError, + DebugOptions, SMTSolverType, SliceOptions, SliceVerifyMethod, VerifyCommand, VerifyError, }; use ariadne::ReportKind; @@ -65,7 +65,7 @@ use z3::{ use z3rro::{ model::InstrumentedModel, probes::ProbeSummary, - prover::{IncrementalMode, ProveResult, Prover}, + prover::{IncrementalMode, ProveResult, Prover, SolverType}, smtlib::Smtlib, util::{PrefixWriter, ReasonUnknown}, }; @@ -696,7 +696,13 @@ impl<'ctx> SmtVcUnit<'ctx> { let span = info_span!("SAT check"); let _entered = span.enter(); - let prover = mk_valid_query_prover(limits_ref, ctx, translate, &self.vc); + let prover = mk_valid_query_prover( + limits_ref, + ctx, + translate, + &self.vc, + options.smt_solver_options.smt_solver.clone(), + ); if options.debug_options.probe { let goal = Goal::new(ctx, false, false, false); @@ -823,9 +829,18 @@ fn mk_valid_query_prover<'smt, 'ctx>( ctx: &'ctx Context, smt_translate: &TranslateExprs<'smt, 'ctx>, valid_query: &Bool<'ctx>, + smt_solver: SMTSolverType, ) -> Prover<'ctx> { + let solver_type = match smt_solver { + SMTSolverType::InternalZ3 => SolverType::InternalZ3, + SMTSolverType::ExternalZ3 => SolverType::ExternalZ3, + SMTSolverType::Swine => SolverType::SWINE, + SMTSolverType::CVC5 => SolverType::CVC5, + SMTSolverType::Yices => SolverType::YICES, + }; + // create the prover and set the params - let mut prover = Prover::new(ctx, IncrementalMode::Native); + let mut prover = Prover::new(ctx, IncrementalMode::Native, solver_type); if let Some(remaining) = limits_ref.time_left() { prover.set_timeout(remaining); } diff --git a/src/main.rs b/src/main.rs index 2c0b5932..da54383f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,7 +39,10 @@ use tokio::task::JoinError; use tracing::{error, info, warn}; use vc::explain::VcExplanation; -use z3rro::{prover::ProveResult, util::ReasonUnknown}; +use z3rro::{ + prover::{ProveResult, ProverCommandError}, + util::ReasonUnknown, +}; pub mod ast; mod driver; @@ -138,6 +141,9 @@ pub struct VerifyCommand { #[command(flatten)] pub debug_options: DebugOptions, + + #[command(flatten)] + pub smt_solver_options: SMTSolverOptions, } #[derive(Debug, Args)] @@ -380,6 +386,28 @@ pub struct DebugOptions { pub probe: bool, } +#[derive(Debug, Default, Args)] +#[command(next_help_heading = "SMT Solver Options")] +pub struct SMTSolverOptions { + #[arg(long, default_value = "default")] + pub smt_solver: SMTSolverType, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, ValueEnum)] +pub enum SMTSolverType { + #[default] + #[value(name = "default")] + InternalZ3, + #[value(name = "z3")] + ExternalZ3, + #[value(name = "swine")] + Swine, + #[value(name = "cvc5")] + CVC5, + #[value(name = "yices")] + Yices, +} + #[derive(Debug, Default, Args)] #[command(next_help_heading = "Slicing Options")] pub struct SliceOptions { @@ -533,6 +561,10 @@ fn finalize_verify_result( tracing::error!("Interrupted"); ExitCode::from(130) // 130 seems to be a standard exit code for CTRL+C } + Err(VerifyError::ProverError(err)) => { + eprintln!("{}", err.to_string()); + ExitCode::from(1) + } } } @@ -612,6 +644,8 @@ pub enum VerifyError { /// The verifier was interrupted. #[error("interrupted")] Interrupted, + #[error("{0}")] + ProverError(#[from] ProverCommandError), } /// Verify a list of `user_files`. The `options.files` value is ignored here. diff --git a/src/opt/fuzz_test.rs b/src/opt/fuzz_test.rs index cdcbd81e..e264a913 100644 --- a/src/opt/fuzz_test.rs +++ b/src/opt/fuzz_test.rs @@ -2,7 +2,8 @@ use proptest::{ prelude::*, test_runner::{TestCaseResult, TestRunner}, }; -use z3rro::prover::{IncrementalMode, ProveResult, Prover}; + +use z3rro::prover::{IncrementalMode, ProveResult, Prover, SolverType}; use crate::{ ast::{ @@ -201,23 +202,24 @@ fn prove_equiv(expr: Expr, optimized: Expr, tcx: &TyCtx) -> TestCaseResult { let smt_ctx = SmtCtx::new(&ctx, tcx); let mut translate = TranslateExprs::new(&smt_ctx); let eq_expr_z3 = translate.t_bool(&eq_expr); - let mut prover = Prover::new(&ctx, IncrementalMode::Native); + let mut prover = Prover::new(&ctx, IncrementalMode::Native, SolverType::InternalZ3); translate .local_scope() .add_assumptions_to_prover(&mut prover); prover.add_provable(&eq_expr_z3); let x = match prover.check_proof() { - ProveResult::Proof => Ok(()), - ProveResult::Counterexample => { + Ok(ProveResult::Proof) => Ok(()), + Ok(ProveResult::Counterexample) => { let model = prover.get_model().unwrap(); Err(TestCaseError::fail(format!( "rewrote {} ...into... {}, but those are not equivalent:\n{}", expr, optimized, model ))) } - ProveResult::Unknown(reason) => { + Ok(ProveResult::Unknown(reason)) => { Err(TestCaseError::fail(format!("unknown result ({})", reason))) } + Err(err) => Err(TestCaseError::fail(format!("{}", err))), }; x } diff --git a/src/opt/unfolder.rs b/src/opt/unfolder.rs index 7cc15b18..d97146bb 100644 --- a/src/opt/unfolder.rs +++ b/src/opt/unfolder.rs @@ -25,7 +25,7 @@ use std::ops::DerefMut; use z3::SatResult; -use z3rro::prover::{IncrementalMode, Prover}; +use z3rro::prover::{IncrementalMode, Prover, SolverType}; use crate::{ ast::{ @@ -60,7 +60,7 @@ impl<'smt, 'ctx> Unfolder<'smt, 'ctx> { // it's important that we use the native incremental mode here, because // the performance benefit from the unfolder relies on many very fast // SAT checks. - let prover = Prover::new(ctx.ctx(), IncrementalMode::Native); + let prover = Prover::new(ctx.ctx(), IncrementalMode::Native, SolverType::InternalZ3); Unfolder { subst: Subst::new(ctx.tcx(), &limits_ref), @@ -98,7 +98,7 @@ impl<'smt, 'ctx> Unfolder<'smt, 'ctx> { // here we want to do a SAT check and not a proof search. if the // expression is e.g. `false`, then we want to get `Unsat` from the // solver and not `Proof`! - if this.prover.check_sat() == SatResult::Unsat { + if this.prover.check_sat() == Ok(SatResult::Unsat) { tracing::trace!(solver=?this.prover, "eliminated zero expr"); None } else { diff --git a/src/slicing/solver.rs b/src/slicing/solver.rs index f0149f96..c4bcb7a7 100644 --- a/src/slicing/solver.rs +++ b/src/slicing/solver.rs @@ -9,16 +9,16 @@ use z3::{ }; use z3rro::{ model::{InstrumentedModel, ModelConsistency}, - prover::{ProveResult, Prover}, + prover::{ProveResult, Prover, ProverCommandError}, util::ReasonUnknown, }; use crate::{ ast::{ExprBuilder, Span}, - resource_limits::{LimitError, LimitsRef}, + resource_limits::LimitsRef, slicing::{ model::{SliceMode, SliceModel}, - util::{PartialMinimizeResult, SubsetExploration}, + util::{at_most_k, PartialMinimizeResult, SubsetExploration}, }, smt::translate_exprs::TranslateExprs, VerifyError, @@ -218,7 +218,10 @@ impl<'ctx> SliceSolver<'ctx> { options, limits_ref, )?; - if exists_forall_solver.check_sat() == SatResult::Sat { + let sat_res = exists_forall_solver + .check_sat() + .map_err(|err| VerifyError::ProverError(err))?; + if sat_res == SatResult::Sat { let model = exists_forall_solver.get_model().unwrap(); let slice_model = SliceModel::from_model(SliceMode::Verify, &self.slice_stmts, selection, &model); @@ -251,7 +254,7 @@ impl<'ctx> SliceSolver<'ctx> { let res = self.prover.check_proof_assuming(&active_toggle_values); let mut slice_searcher = SliceModelSearch::new(active_toggle_values.clone()); - if let ProveResult::Proof = res { + if let Ok(ProveResult::Proof) = res { slice_searcher.found_active(self.prover.get_unsat_core()); } @@ -351,7 +354,10 @@ impl<'ctx> SliceSolver<'ctx> { self.prover.push(); slice_sat_binary_search(&mut self.prover, &active_toggle_values, options, limits_ref)?; - let res = self.prover.check_proof(); + let res = self + .prover + .check_proof() + .map_err(|err| VerifyError::ProverError(err))?; let model = if let Some(model) = self.prover.get_model() { assert!(matches!( res, @@ -461,16 +467,16 @@ fn slice_sat_binary_search<'ctx>( ) -> Result<(), VerifyError> { assert_eq!(prover.level(), 2); - let slice_vars: Vec<(&Bool<'ctx>, i32)> = - active_slice_vars.iter().map(|value| (value, 1)).collect(); - let set_at_most_true = |prover: &mut Prover<'ctx>, at_most_n: usize| { prover.pop(); prover.push(); let ctx = prover.get_context(); - let at_most_n_true = Bool::pb_le(ctx, &slice_vars, at_most_n as i32); - prover.add_assumption(&at_most_n_true); + if !active_slice_vars.is_empty() { + let at_most_n_true = + at_most_k(ctx, at_most_n, active_slice_vars, prover.get_solver_type()); + prover.add_assumption(&at_most_n_true); + } }; // TODO: we could have min_least_bound set to 1 if we could conclude for @@ -485,7 +491,7 @@ fn slice_sat_binary_search<'ctx>( // the fix would be to track explicitly whether we can make that assumption // that min_least_bound is 1. let min_least_bound = 0; - let mut minimize = PartialMinimizer::new(min_least_bound..=slice_vars.len()); + let mut minimize = PartialMinimizer::new(min_least_bound..=active_slice_vars.len()); let mut cur_solver_n = None; let mut slice_searcher = SliceModelSearch::new(active_slice_vars.to_vec()); @@ -506,7 +512,10 @@ fn slice_sat_binary_search<'ctx>( if let Some(timeout) = limits_ref.time_left() { prover.set_timeout(timeout); } - let res = prover.check_sat(); + + let res = prover + .check_sat() + .map_err(|err| VerifyError::ProverError(err))?; entered.record("res", tracing::field::debug(res)); @@ -571,7 +580,9 @@ fn slice_sat_binary_search<'ctx>( if let Some(timeout) = limits_ref.time_left() { prover.set_timeout(timeout); } - let res = prover.check_sat(); + let res = prover + .check_sat() + .map_err(|err| VerifyError::ProverError(err))?; if minimize.min_accept().is_some() { assert!(res == SatResult::Sat || res == SatResult::Unknown); } else if minimize.max_reject().is_some() { @@ -593,7 +604,7 @@ pub fn slice_unsat_search<'ctx>( prover: &mut Prover<'ctx>, options: &SliceSolveOptions, limits_ref: &LimitsRef, -) -> Result>>, LimitError> { +) -> Result>>, VerifyError> { let mut slice_searcher = SliceModelSearch::new(exploration.variables().iter().cloned().collect_vec()); let all_variables = exploration.variables().clone(); @@ -602,12 +613,12 @@ pub fn slice_unsat_search<'ctx>( limits_ref.check_limits()?; match check_proof_seed(&all_variables, prover, limits_ref, &seed) { - ProveResult::Proof => { + Ok(ProveResult::Proof) => { // now start the shrinking, then block up let res = exploration.shrink_block_unsat(seed, |seed| { match check_proof_seed(&all_variables, prover, limits_ref, seed) { - ProveResult::Proof => Some(unsat_core_to_seed(prover, &all_variables)), - ProveResult::Counterexample | ProveResult::Unknown(_) => None, + Ok(ProveResult::Proof) => Some(unsat_core_to_seed(prover, &all_variables)), + _ => None, } }); @@ -620,16 +631,16 @@ pub fn slice_unsat_search<'ctx>( SliceMinimality::Size => exploration.block_at_least(res.len()), } } - ProveResult::Counterexample => { + Ok(ProveResult::Counterexample) => { // grow the counterexample and then block down exploration.grow_block_sat(seed, |seed| { match check_proof_seed(&all_variables, prover, limits_ref, seed) { - ProveResult::Counterexample => true, - ProveResult::Proof | ProveResult::Unknown(_) => false, + Ok(ProveResult::Counterexample) => true, + _ => false, } }); } - ProveResult::Unknown(_) => { + Ok(ProveResult::Unknown(_)) => { exploration.block_this(&seed); match options.unknown { @@ -643,6 +654,7 @@ pub fn slice_unsat_search<'ctx>( } } } + Err(err) => return Err(VerifyError::ProverError(err)), } } @@ -655,7 +667,7 @@ fn check_proof_seed<'ctx>( prover: &mut Prover<'ctx>, limits_ref: &LimitsRef, seed: &IndexSet>, -) -> ProveResult { +) -> Result { let mut timeout = Duration::from_millis(100); if let Some(time_left) = limits_ref.time_left() { timeout = timeout.min(time_left); diff --git a/src/slicing/transform_test.rs b/src/slicing/transform_test.rs index e0d04a2c..0730b4fe 100644 --- a/src/slicing/transform_test.rs +++ b/src/slicing/transform_test.rs @@ -3,7 +3,7 @@ use itertools::Itertools; use z3rro::{ model::SmtEval, - prover::{IncrementalMode, ProveResult, Prover}, + prover::{IncrementalMode, ProveResult, Prover, SolverType}, }; use crate::{ @@ -125,7 +125,7 @@ fn prove_equiv( let smt_ctx = SmtCtx::new(&ctx, tcx); let mut translate = TranslateExprs::new(&smt_ctx); let eq_expr_z3 = translate.t_bool(&eq_expr); - let mut prover = Prover::new(&ctx, IncrementalMode::Native); + let mut prover = Prover::new(&ctx, IncrementalMode::Native, SolverType::InternalZ3); translate .local_scope() .add_assumptions_to_prover(&mut prover); @@ -135,15 +135,16 @@ fn prove_equiv( } prover.add_provable(&eq_expr_z3); let x = match prover.check_proof() { - ProveResult::Proof => Ok(()), - ProveResult::Counterexample => { + Ok(ProveResult::Proof) => Ok(()), + Ok(ProveResult::Counterexample) => { let model = prover.get_model().unwrap(); Err(format!( "we want to rewrite {:?} ...into... {:?} under assumptions {:?}, but those are not equivalent:\n{}\n original evaluates to {}\n rewritten evaluates to {}", stmt1, stmt2, assumptions, &model, translate.t_eureal(&stmt1_vc).eval(&model).unwrap(), translate.t_eureal(&stmt2_vc).eval(&model).unwrap() )) } - ProveResult::Unknown(reason) => Err(format!("unknown result ({})", reason)), + Ok(ProveResult::Unknown(reason)) => Err(format!("unknown result ({})", reason)), + Err(err) => Err(format!("{}", err)), }; x } diff --git a/src/slicing/util.rs b/src/slicing/util.rs index ec9bce88..3d4ed9a3 100644 --- a/src/slicing/util.rs +++ b/src/slicing/util.rs @@ -6,7 +6,11 @@ use std::{ use indexmap::IndexSet; use itertools::Itertools; use tracing::{instrument, trace}; -use z3::{ast::Bool, Context, SatResult, Solver}; +use z3::{ + ast::{Bool, Int}, + Context, SatResult, Solver, +}; +use z3rro::prover::SolverType; /// A result of a test during the partial minimization. Either we accept all /// values from this value upwards, we reject all values from this value @@ -411,3 +415,30 @@ impl<'ctx> SubsetExploration<'ctx> { current } } + +/// Create an SMT expression that is true if at most k of the given boolean variables evaluate to true +pub fn at_most_k<'ctx>( + ctx: &'ctx Context, + k: usize, + values: &[Bool<'ctx>], + solver_type: SolverType, +) -> Bool<'ctx> { + match solver_type { + SolverType::CVC5 | SolverType::YICES => { + let int_values: Vec> = values + .iter() + .map(|b| b.ite(&Int::from_i64(ctx, 1), &Int::from_i64(ctx, 0))) + .collect(); + + let sum = Int::add(ctx, &int_values); + + let k_int = Int::from_i64(ctx, k as i64); + sum.le(&k_int) + } + _ => { + let slice_vars: Vec<(&Bool<'ctx>, i32)> = + values.iter().map(|value| (value, 1)).collect(); + Bool::pb_le(ctx, &slice_vars, k as i32) + } + } +} diff --git a/z3rro/Cargo.toml b/z3rro/Cargo.toml index ee14818c..1259ded1 100644 --- a/z3rro/Cargo.toml +++ b/z3rro/Cargo.toml @@ -15,6 +15,7 @@ thiserror = "1.0" im-rc = "15" enum-map = "2.7.3" itertools = "0.14.0" +regex = "1" [features] datatype-eureal = [] diff --git a/z3rro/src/model.rs b/z3rro/src/model.rs index bd92d9c0..a6a2bb51 100644 --- a/z3rro/src/model.rs +++ b/z3rro/src/model.rs @@ -7,7 +7,9 @@ use std::{ }; use num::{BigInt, BigRational}; + use thiserror::Error; + use z3::{ ast::{Ast, Bool, Dynamic, Int, Real}, FuncDecl, FuncInterp, Model, diff --git a/z3rro/src/prover.rs b/z3rro/src/prover.rs index a865f510..bfbd6b1a 100644 --- a/z3rro/src/prover.rs +++ b/z3rro/src/prover.rs @@ -1,6 +1,17 @@ //! Not a SAT solver, but a prover. There's a difference. +use itertools::Itertools; +use thiserror::Error; + +use std::{ + collections::VecDeque, + fmt::Display, + io::{Seek, SeekFrom, Write}, + path::Path, + process::{Command, Output}, + time::Duration, +}; -use std::{fmt::Display, time::Duration}; +use tempfile::NamedTempFile; use z3::{ ast::{forall_const, Ast, Bool, Dynamic}, @@ -13,6 +24,25 @@ use crate::{ util::{set_solver_timeout, ReasonUnknown}, }; +#[derive(Debug, Error, PartialEq)] +pub enum ProverCommandError { + #[error("Process execution failed: {0}")] + ProcessError(String), + #[error("Parse error")] + ParseError, + #[error("Unexpected result from prover: {0}")] + UnexpectedResultError(String), +} + +#[derive(Debug, PartialEq, Clone)] +pub enum SolverType { + InternalZ3, + ExternalZ3, + SWINE, + CVC5, + YICES, +} + /// The result of a prove query. #[derive(Debug)] pub enum ProveResult { @@ -21,6 +51,180 @@ pub enum ProveResult { Unknown(ReasonUnknown), } +/// If z3 is used as the SMT solver, it is not necessary to store +/// a counterexample (for Sat) or reason (for Unknown), since the +/// Z3 solver already retains this information internally. +/// In this case, it is only used to store the SAT result. +/// +/// For SwInE, this can be used either to +/// 1) transport the result from SwInE, or +/// 2) store SAT result along with a reason for Unknown. +#[derive(Debug, Clone)] +pub enum SolverResult { + Unsat, + Sat(Option), + Unknown(Option), +} + +impl SolverResult { + fn to_sat_result(&self) -> SatResult { + match self { + SolverResult::Unsat => SatResult::Unsat, + SolverResult::Sat(_) => SatResult::Sat, + SolverResult::Unknown(_) => SatResult::Unknown, + } + } +} + +fn call_solver( + file_path: &Path, + solver: SolverType, + timeout: Option, + sat_result: Option, +) -> Result { + let (solver, args) = match solver { + SolverType::InternalZ3 => { + unreachable!("The function 'call_solver' should never be called for z3"); + } + SolverType::ExternalZ3 => { + let mut args: Vec = match sat_result { + Some(SatResult::Unsat) => unreachable!( + "The function 'call_solver' should not be called again after an 'unsat' result" + ), + Some(SatResult::Sat) => vec!["-model".to_string()], + Some(SatResult::Unknown) | None => vec![], + }; + + if let Some(t) = timeout { + args.push(format!("-t:{}", t.as_millis())); + } + + ("z3", args) + } + SolverType::SWINE => { + let args: Vec = match sat_result { + Some(SatResult::Unsat) => unreachable!( + "The function 'call_solver' should not be called again after an 'unsat' result" + ), + _ => vec!["--no-version".to_string()], + }; + + ("swine", args) + } + SolverType::CVC5 => { + let mut args: Vec = match sat_result { + Some(SatResult::Unsat) => unreachable!( + "The function 'call_solver' should not be called again after an 'unsat' result" + ), + Some(SatResult::Sat) => vec!["--produce-models".to_string()], + _ => vec![], + }; + + if let Some(t) = timeout { + args.push(format!("--tlimit={}", t.as_millis())); + } + + ("cvc5", args) + } + SolverType::YICES => { + let mut args: Vec = match sat_result { + Some(SatResult::Unsat) => unreachable!( + "The function 'call_solver' should not be called again after an 'unsat' result" + ), + Some(SatResult::Sat) => vec!["--smt2-model-format".to_string()], + _ => vec![], + }; + + if let Some(t) = timeout { + let secs = t.as_secs(); + + if secs > 0 { + args.push(format!("--timeout={}", secs)); + } else { + panic!("Timeout must be at least one second. Yices does not support timeouts shorter than 1 second.") + } + } + + ("yices-smt2", args) + } + }; + + Command::new(solver).args(&args).arg(file_path).output() +} + +/// To execute the SMT solver correctly, specific modifications to the input are required: +/// 1) For SwInE, remove lines that contain a `forall` quantifier or the declaration of the exponential function (`exp``). +/// 2) For other solvers, add a line to set logic, and remove incorrect assertions such as `(assert add)`. +/// 3) For solvers that do not support at-most, convert those assertions into equivalent logic. +fn transform_input_lines(input: &str, solver: SolverType, timeout: Option) -> String { + let timeout_option = if let Some(t) = timeout { + match solver { + SolverType::InternalZ3 => { + unreachable!( + "The function 'transform_input_lines' should never be called for internal z3" + ); + } + SolverType::SWINE => format!("(set-option :timeout {})\n", t.as_millis()), + _ => "".to_string(), + } + } else { + "".to_string() + }; + + let mut output = match solver { + SolverType::CVC5 | SolverType::YICES => { + let mut output = String::new(); + let logic = if input.contains("*") || input.contains("/") { + "(set-logic QF_NIRA)\n" + } else { + "(set-logic QF_LIRA)\n" + }; + output.push_str(logic); + output + } + _ => String::new(), + }; + + output.push_str(&timeout_option); + + if solver == SolverType::ExternalZ3 { + output.push_str(input); + } else { + let mut tmp_buffer: VecDeque = VecDeque::new(); + let mut input_buffer: VecDeque = input.chars().collect(); + let mut cnt = 0; + + let condition = |tmp: &str| match solver { + SolverType::SWINE => !tmp.contains("declare-fun exp") && !tmp.contains("forall"), + _ => !tmp.contains("(assert and)"), + }; + + // Collect characters until all opened parentheses are closed, and + // keep this block if it does not contain 'declare-fun exp' or 'forall'. + while let Some(c) = input_buffer.pop_front() { + tmp_buffer.push_back(c); + match c { + '(' => { + cnt += 1; + } + ')' => { + cnt -= 1; + if cnt == 0 { + let tmp: String = tmp_buffer.iter().collect(); + if condition(&tmp) { + output.push_str(&tmp); + } + tmp_buffer.clear(); + } + } + _ => {} + } + } + } + + output +} + impl Display for ProveResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -50,7 +254,7 @@ enum StackSolver<'ctx> { } #[derive(Debug)] -struct LastSatResult { +struct LastSatSolverResult { /// Whether the current model is consistent with the assertions. If the SMT /// solver returned [`SatResult::Unknown`], it is /// [`ModelConsistency::Unknown`]. @@ -59,7 +263,7 @@ struct LastSatResult { /// It is reset any time the assertions on the solver are modified. /// Sometimes Z3 caches on its own, but it is not reliable. Therefore, we do /// it here as well to be sure. - last_result: SatResult, + last_result: SolverResult, } /// A prover wraps a SAT solver, but it's used to prove validity of formulas. @@ -83,13 +287,16 @@ pub struct Prover<'ctx> { level: usize, /// The minimum level where an assertion was added to the solver. min_level_with_provables: Option, + /// SMT solver type + smt_solver: SolverType, /// Cached information about the last SAT/proof check call. - last_result: Option, + last_result: Option, + result_solver: Option>, } impl<'ctx> Prover<'ctx> { /// Create a new prover with the given [`Context`] and [`IncrementalMode`]. - pub fn new(ctx: &'ctx Context, mode: IncrementalMode) -> Self { + pub fn new(ctx: &'ctx Context, mode: IncrementalMode, solver_type: SolverType) -> Self { Prover { ctx, timeout: None, @@ -101,7 +308,9 @@ impl<'ctx> Prover<'ctx> { }, level: 0, min_level_with_provables: None, + smt_solver: solver_type, last_result: None, + result_solver: None, } } @@ -154,34 +363,88 @@ impl<'ctx> Prover<'ctx> { } /// `self.check_proof_assuming(&[])`. - pub fn check_proof(&mut self) -> ProveResult { + pub fn check_proof(&mut self) -> Result { self.check_proof_assuming(&[]) } /// Do the SAT check, but consider a check with no provables to be a /// [`ProveResult::Proof`]. - pub fn check_proof_assuming(&mut self, assumptions: &[Bool<'ctx>]) -> ProveResult { + pub fn check_proof_assuming( + &mut self, + assumptions: &[Bool<'ctx>], + ) -> Result { if !self.has_provables() { - return ProveResult::Proof; + return Ok(ProveResult::Proof); } - let res = match &self.last_result { - Some(cached_result) if assumptions.is_empty() => cached_result.last_result, + match self.smt_solver { + SolverType::InternalZ3 => { + let res = match &self.last_result { + Some(cached_result) if assumptions.is_empty() => { + cached_result.last_result.clone() + } + _ => { + let solver = self.get_solver(); + let res = if assumptions.is_empty() { + solver.check() + } else { + solver.check_assumptions(assumptions) + }; + + let solver_result = match res { + SatResult::Unsat => SolverResult::Unsat, + SatResult::Unknown => SolverResult::Unknown(None), + SatResult::Sat => SolverResult::Sat(None), + }; + self.cache_result(solver_result.clone()); + solver_result + } + }; + + match res { + SolverResult::Unsat => Ok(ProveResult::Proof), + SolverResult::Unknown(_) => { + Ok(ProveResult::Unknown(self.get_reason_unknown().unwrap())) + } + SolverResult::Sat(_) => Ok(ProveResult::Counterexample), + } + } _ => { - let solver = self.get_solver(); - let res = if assumptions.is_empty() { - solver.check() - } else { - solver.check_assumptions(assumptions) + let res = match &self.last_result { + Some(cached_result) if assumptions.is_empty() => { + Ok(cached_result.last_result.clone()) + } + _ => { + let solver_result = self.run_solver(assumptions)?; + + if let SolverResult::Sat(Some(cex)) = solver_result.clone() { + if let Some(solver) = &self.result_solver { + solver.from_string(cex.clone()); + solver.check(); + } else { + let solver = Solver::new(self.ctx); + solver.from_string(cex.clone()); + solver.check(); + self.result_solver = Some(solver); + } + } + self.cache_result(solver_result.clone()); + + Ok(solver_result) + } }; - self.cache_result(res); - res + + let sat_result = res?; + + match sat_result { + SolverResult::Unsat => Ok(ProveResult::Proof), + SolverResult::Unknown(r) => { + let reason = r.unwrap_or(ReasonUnknown::Other("".to_string())); + Ok(ProveResult::Unknown(reason)) + } + SolverResult::Sat(_) => Ok(ProveResult::Counterexample), + } } - }; - match res { - SatResult::Unsat => ProveResult::Proof, - SatResult::Unknown => ProveResult::Unknown(self.get_reason_unknown().unwrap()), - SatResult::Sat => ProveResult::Counterexample, } } @@ -194,25 +457,55 @@ impl<'ctx> Prover<'ctx> { } /// Do the regular SAT check. - pub fn check_sat(&mut self) -> SatResult { + pub fn check_sat(&mut self) -> Result { if let Some(cached_result) = &self.last_result { - return cached_result.last_result; + return Ok(cached_result.last_result.to_sat_result()); } - let res = self.get_solver().check(); - self.cache_result(res); - res + + let sat_result = match self.smt_solver { + SolverType::InternalZ3 => { + let sat_result = self.get_solver().check(); + + let solver_result = match sat_result { + SatResult::Unsat => SolverResult::Unsat, + SatResult::Unknown => SolverResult::Unknown(None), + SatResult::Sat => SolverResult::Sat(None), + }; + self.cache_result(solver_result); + + sat_result + } + _ => { + let solver_result = self.run_solver(&[])?; + if let SolverResult::Sat(Some(cex)) = solver_result.clone() { + if let Some(solver) = &self.result_solver { + solver.from_string(cex.clone()); + solver.check(); + } else { + let solver = Solver::new(self.ctx); + solver.from_string(cex.clone()); + solver.check(); + self.result_solver = Some(solver); + } + } + self.cache_result(solver_result.clone()); + solver_result.to_sat_result() + } + }; + + Ok(sat_result) } /// Save the result of the last SAT/proof check. - fn cache_result(&mut self, sat_result: SatResult) { - let model_consistency = match sat_result { - SatResult::Sat => Some(ModelConsistency::Consistent), - SatResult::Unknown => Some(ModelConsistency::Unknown), - SatResult::Unsat => None, + fn cache_result(&mut self, solver_result: SolverResult) { + let model_consistency = match solver_result { + SolverResult::Sat(_) => Some(ModelConsistency::Consistent), + SolverResult::Unknown(_) => Some(ModelConsistency::Unknown), + SolverResult::Unsat => None, }; - self.last_result = Some(LastSatResult { + self.last_result = Some(LastSatSolverResult { model_consistency, - last_result: sat_result, + last_result: solver_result, }); } @@ -224,7 +517,17 @@ impl<'ctx> Prover<'ctx> { /// [`ModelConsistency::Inconsistent`]. pub fn get_model(&self) -> Option> { let consistency = self.last_result.as_ref()?.model_consistency?; - let model = self.get_solver().get_model()?; + let model = match self.smt_solver { + SolverType::InternalZ3 => self.get_solver().get_model()?, + _ => { + let solver = match &self.result_solver { + Some(solver) => solver, + None => &Solver::new(self.ctx), + }; + + solver.get_model()? + } + }; Some(InstrumentedModel::new(consistency, model)) } @@ -235,9 +538,23 @@ impl<'ctx> Prover<'ctx> { /// See [`Solver::get_reason_unknown`]. pub fn get_reason_unknown(&self) -> Option { - self.get_solver() - .get_reason_unknown() - .map(|reason| reason.parse().unwrap()) + match self.smt_solver { + SolverType::InternalZ3 => self + .get_solver() + .get_reason_unknown() + .map(|reason| reason.parse().unwrap()), + _ => { + if let Some(cached_result) = &self.last_result { + if let SolverResult::Unknown(reason_unknown) = &cached_result.last_result { + reason_unknown.clone() + } else { + Some(ReasonUnknown::Other("".to_string())) + } + } else { + Some(ReasonUnknown::Other("".to_string())) + } + } + } } /// See [`Solver::push`]. @@ -321,7 +638,7 @@ impl<'ctx> Prover<'ctx> { &[], &Bool::and(self.ctx, &self.get_assertions()).not(), ); - let mut res = Prover::new(self.ctx, IncrementalMode::Native); // TODO + let mut res = Prover::new(self.ctx, IncrementalMode::Native, SolverType::InternalZ3); // TODO res.add_assumption(&theorem); res } @@ -330,13 +647,111 @@ impl<'ctx> Prover<'ctx> { pub fn get_smtlib(&self) -> Smtlib { Smtlib::from_solver(self.get_solver()) } + + pub fn get_solver_type(&self) -> SolverType { + self.smt_solver.clone() + } + + /// Execute an SMT solver (other than z3) + fn run_solver(&mut self, assumptions: &[Bool<'_>]) -> Result { + let mut smt_file: NamedTempFile = NamedTempFile::new().unwrap(); + smt_file + .write_all(self.generate_smtlib(assumptions).as_bytes()) + .unwrap(); + + let mut output = call_solver(smt_file.path(), self.get_solver_type(), self.timeout, None) + .map_err(|e| ProverCommandError::ProcessError(e.to_string()))?; + + if !output.status.success() { + return Err(ProverCommandError::ProcessError( + String::from_utf8_lossy(&output.stderr).to_string(), + )); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let first_line = stdout.lines().next().unwrap_or("").trim().to_lowercase(); + + let sat_result = match first_line.as_str() { + "sat" => { + smt_file + .as_file_mut() + .seek(SeekFrom::End(0)) + .map_err(|e| ProverCommandError::ProcessError(e.to_string()))?; + smt_file + .write_all(b"(get-model)\n") + .map_err(|e| ProverCommandError::ProcessError(e.to_string()))?; + + SatResult::Sat + } + "unsat" => SatResult::Unsat, + "unknown" => { + if self.smt_solver != SolverType::YICES { + smt_file + .as_file_mut() + .seek(SeekFrom::End(0)) + .map_err(|e| ProverCommandError::ProcessError(e.to_string()))?; + smt_file + .write_all(b"(get-info :reason-unknown)\n") + .map_err(|e| ProverCommandError::ProcessError(e.to_string()))?; + } + SatResult::Unknown + } + _ => { + return Err(ProverCommandError::UnexpectedResultError( + stdout.into_owned(), + )) + } + }; + + if sat_result == SatResult::Sat || sat_result == SatResult::Unknown { + output = call_solver( + smt_file.path(), + self.get_solver_type(), + self.timeout, + Some(sat_result), + ) + .map_err(|e| ProverCommandError::ProcessError(e.to_string()))?; + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let mut lines_buffer: VecDeque<&str> = stdout.lines().collect(); + lines_buffer + .pop_front() + .ok_or(ProverCommandError::ParseError)?; + let solver_result = match sat_result { + SatResult::Unsat => SolverResult::Unsat, + SatResult::Unknown => { + SolverResult::Unknown(Some(ReasonUnknown::Other(lines_buffer.iter().join("\n")))) + } + SatResult::Sat => { + let cex = lines_buffer.iter().join(""); + SolverResult::Sat(Some(cex)) + } + }; + + Ok(solver_result) + } + + fn generate_smtlib(&self, assumptions: &[Bool<'_>]) -> String { + let mut smtlib = self.get_smtlib(); + + if assumptions.is_empty() { + smtlib.add_check_sat(); + } else { + smtlib.add_check_sat_assuming(assumptions.iter().map(|a| a.to_string()).collect()); + }; + + let smtlib = smtlib.into_string(); + + transform_input_lines(&smtlib, self.get_solver_type(), self.timeout) + } } #[cfg(test)] mod test { use z3::{ast::Bool, Config, Context, SatResult}; - use crate::prover::IncrementalMode; + use crate::prover::{IncrementalMode, SolverType}; use super::{ProveResult, Prover}; @@ -344,18 +759,18 @@ mod test { fn test_prover() { for mode in [IncrementalMode::Native, IncrementalMode::Emulated] { let ctx = Context::new(&Config::default()); - let mut prover = Prover::new(&ctx, mode); - assert!(matches!(prover.check_proof(), ProveResult::Proof)); - assert_eq!(prover.check_sat(), SatResult::Sat); + let mut prover = Prover::new(&ctx, mode, SolverType::InternalZ3); + assert!(matches!(prover.check_proof(), Ok(ProveResult::Proof))); + assert_eq!(prover.check_sat(), Ok(SatResult::Sat)); prover.push(); prover.add_assumption(&Bool::from_bool(&ctx, true)); - assert!(matches!(prover.check_proof(), ProveResult::Proof)); - assert_eq!(prover.check_sat(), SatResult::Sat); + assert!(matches!(prover.check_proof(), Ok(ProveResult::Proof))); + assert_eq!(prover.check_sat(), Ok(SatResult::Sat)); prover.pop(); - assert!(matches!(prover.check_proof(), ProveResult::Proof)); - assert_eq!(prover.check_sat(), SatResult::Sat); + assert!(matches!(prover.check_proof(), Ok(ProveResult::Proof))); + assert_eq!(prover.check_sat(), Ok(SatResult::Sat)); } } } diff --git a/z3rro/src/smtlib.rs b/z3rro/src/smtlib.rs index 9c767000..9127b42e 100644 --- a/z3rro/src/smtlib.rs +++ b/z3rro/src/smtlib.rs @@ -30,7 +30,17 @@ impl Smtlib { self.0.push_str("\n(check-sat)"); } - /// Add a `(check-sat)` command at the end. + /// Add a `(check-sat-assuming)` command at the end + pub fn add_check_sat_assuming(&mut self, assumptions: Vec) { + let assumptions_str: Vec = assumptions.iter().map(|a| a.to_string()).collect(); + + self.0.push_str(&format!( + "\n(check-sat-assuming ({}))", + assumptions_str.join(" ").as_str() + )); + } + + /// Add a `(get-model)` command at the end for counterexamples and a `(get-info :reason-unknown)` for unknown results. pub fn add_details_query(&mut self, prove_result: &ProveResult) { match prove_result { ProveResult::Proof => {} diff --git a/z3rro/src/test.rs b/z3rro/src/test.rs index 09d017cd..8291786d 100644 --- a/z3rro/src/test.rs +++ b/z3rro/src/test.rs @@ -4,7 +4,7 @@ use z3::{ast::Bool, Config, Context, SatResult}; -use crate::prover::{IncrementalMode, ProveResult, Prover}; +use crate::prover::{IncrementalMode, ProveResult, Prover, SolverType}; use super::scope::SmtScope; @@ -18,23 +18,24 @@ pub fn test_prove(f: impl for<'ctx> FnOnce(&'ctx Context, &mut SmtScope<'ctx>) - let mut scope = SmtScope::new(); let theorem = f(&ctx, &mut scope); - let mut prover = Prover::new(&ctx, IncrementalMode::Native); + let mut prover = Prover::new(&ctx, IncrementalMode::Native, SolverType::InternalZ3); scope.add_assumptions_to_prover(&mut prover); assert_eq!( prover.check_sat(), - SatResult::Sat, + Ok(SatResult::Sat), "SmtScope is inconsistent" ); prover.add_provable(&theorem); match prover.check_proof() { - ProveResult::Counterexample => panic!( + Ok(ProveResult::Counterexample) => panic!( "counter-example: {:?}\nassertions:\n{:?}", prover.get_model(), prover.get_assertions() ), - ProveResult::Unknown(reason) => panic!("solver returned unknown ({})", reason), - ProveResult::Proof => {} + Ok(ProveResult::Unknown(reason)) => panic!("solver returned unknown ({})", reason), + Ok(ProveResult::Proof) => {} + Err(e) => panic!("{}", e), }; }