Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions crates/flux-infer/src/lean_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use rustc_span::ErrorGuaranteed;

use crate::{
fixpoint_encoding::{ConstDeps, InterpretedConst, KVarSolutions, SortDeps, fixpoint},
lean_format::{
self, BoolMode, LeanCtxt, WithLeanCtxt, def_id_to_pascal_case, snake_case_to_pascal_case,
},
lean_format::{self, LeanCtxt, WithLeanCtxt, def_id_to_pascal_case, snake_case_to_pascal_case},
};

/// Helper macro to create Vec<String> from string-like values
Expand Down Expand Up @@ -290,7 +288,6 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
adt_map: &self.sort_deps.adt_map,
opaque_adt_map: &self.sort_deps.opaque_sorts,
kvar_solutions: &self.kvar_solutions,
bool_mode: BoolMode::Bool,
}
}

Expand All @@ -314,6 +311,10 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
snake_case_to_pascal_case(&name)
}

fn open_classical(&self) -> &str {
"open Classical"
}

fn new(
genv: GlobalEnv<'genv, 'tcx>,
def_id: MaybeExternId,
Expand Down Expand Up @@ -416,6 +417,7 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
let path = file.path(self.genv);
if let Some(mut file) = create_file_with_dirs(path)? {
writeln!(file, "{}", &LeanFile::Fluxlib.import(self.genv))?;
writeln!(file, "{}", self.open_classical())?;
namespaced(&mut file, |f| {
writeln!(f, "def {} := sorry", WithLeanCtxt { item: sort, cx: &self.lean_cx() })
})?;
Expand Down Expand Up @@ -456,6 +458,7 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
for dep in self.data_decl_dependencies(data_decl) {
writeln!(file, "{}", dep.import(self.genv))?;
}
writeln!(file, "{}", self.open_classical())?;

// write data decl
namespaced(&mut file, |f| {
Expand Down Expand Up @@ -527,6 +530,7 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
for dep in self.fun_def_dependencies(did, fun_def) {
writeln!(file, "{}", dep.import(self.genv))?;
}
writeln!(file, "{}", self.open_classical())?;

// write fun def
namespaced(&mut file, |f| {
Expand Down Expand Up @@ -555,6 +559,8 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
writeln!(file, "{}", self.sort_file(&dep).import(self.genv))?;
}

writeln!(file, "{}", self.open_classical())?;

namespaced(&mut file, |f| {
if let Some(comment) = &const_decl.comment {
writeln!(f, "--{comment}")?;
Expand Down Expand Up @@ -683,6 +689,7 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
let path = LeanFile::Vc(def_id).path(self.genv);
if let Some(mut file) = create_file_with_dirs(path)? {
self.generate_vc_imports(&mut file)?;
writeln!(file, "{}", self.open_classical())?;

let vc_name = vc_name(self.genv, def_id);
// 3. Write the VC
Expand Down Expand Up @@ -715,6 +722,7 @@ impl<'genv, 'tcx> LeanEncoder<'genv, 'tcx> {
if let Some(mut file) = create_file_with_dirs(path)? {
writeln!(file, "{}", LeanFile::Fluxlib.import(self.genv))?;
writeln!(file, "{}", LeanFile::Vc(def_id).import(self.genv))?;
writeln!(file, "{}", self.open_classical())?;
namespaced(&mut file, |f| {
writeln!(f, "def {proof_name} : {vc_name} := by")?;
writeln!(f, " unfold {vc_name}")?;
Expand Down
49 changes: 7 additions & 42 deletions crates/flux-infer/src/lean_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,12 @@ use crate::fixpoint_encoding::{
},
};

#[derive(Debug, Clone, Copy)]
pub enum BoolMode {
Bool,
Prop,
}

pub struct LeanCtxt<'a, 'genv, 'tcx> {
pub genv: GlobalEnv<'genv, 'tcx>,
pub pretty_var_map: &'a PrettyMap<LocalVar>,
pub adt_map: &'a FxIndexSet<DefId>,
pub opaque_adt_map: &'a [(FluxDefId, SortDecl)],
pub kvar_solutions: &'a KVarSolutions,
pub bool_mode: BoolMode,
}

impl<'a, 'genv, 'tcx> LeanCtxt<'a, 'genv, 'tcx> {
pub(crate) fn with_bool_mode(&self, bool_mode: BoolMode) -> Self {
LeanCtxt { bool_mode, ..*self }
}
}

pub struct WithLeanCtxt<'a, 'b, 'genv, 'tcx, T> {
Expand Down Expand Up @@ -320,12 +307,7 @@ impl LeanFmt for Sort {
fn lean_fmt(&self, f: &mut std::fmt::Formatter, cx: &LeanCtxt) -> std::fmt::Result {
match self {
Sort::Int => write!(f, "Int"),
Sort::Bool => {
match cx.bool_mode {
BoolMode::Bool => write!(f, "Bool"),
BoolMode::Prop => write!(f, "Prop"),
}
}
Sort::Bool => write!(f, "Prop"),
Sort::Real => write!(f, "Real"),
Sort::Str => write!(f, "String"),
Sort::Func(f_sort) => {
Expand Down Expand Up @@ -394,12 +376,7 @@ impl LeanFmt for Expr {
Expr::Constant(c) => {
match c {
Constant::Numeral(n) => write!(f, "{n}",),
Constant::Boolean(b) => {
match cx.bool_mode {
BoolMode::Bool => write!(f, "{}", if *b { "true" } else { "false" }),
BoolMode::Prop => write!(f, "{}", if *b { "True" } else { "False" }),
}
}
Constant::Boolean(b) => write!(f, "{}", if *b { "True" } else { "False" }),
Constant::String(s) => write!(f, "{}", s.display()),
Constant::Real(n) => write!(f, "{n}.0"),
Constant::BitVec(bv, size) => write!(f, "{}#{}", bv, size),
Expand Down Expand Up @@ -452,8 +429,7 @@ impl LeanFmt for Expr {
write!(f, ")")?;
if let Some(out_sort) = out_sort {
write!(f, " : (")?;
let sort_cx = cx.with_bool_mode(BoolMode::Bool);
out_sort.lean_fmt(f, &sort_cx)?;
out_sort.lean_fmt(f, &cx)?;
write!(f, "))")?;
}
Ok(())
Expand All @@ -462,10 +438,7 @@ impl LeanFmt for Expr {
write!(f, "(")?;
for (i, expr) in exprs.iter().enumerate() {
if i > 0 {
match cx.bool_mode {
BoolMode::Bool => write!(f, " && ")?,
BoolMode::Prop => write!(f, " ∧ ")?,
};
write!(f, " ∧ ")?;
}
expr.lean_fmt(f, cx)?;
}
Expand All @@ -475,10 +448,7 @@ impl LeanFmt for Expr {
write!(f, "(")?;
for (i, expr) in exprs.iter().enumerate() {
if i > 0 {
match cx.bool_mode {
BoolMode::Bool => write!(f, " || ")?,
BoolMode::Prop => write!(f, " ∨ ")?,
};
write!(f, " ∨ ")?;
}
expr.lean_fmt(f, cx)?;
}
Expand All @@ -501,10 +471,7 @@ impl LeanFmt for Expr {
}
Expr::Not(inner) => {
write!(f, "(")?;
match cx.bool_mode {
BoolMode::Bool => write!(f, "!")?,
BoolMode::Prop => write!(f, "¬")?,
};
write!(f, "¬")?;
inner.as_ref().lean_fmt(f, cx)?;
write!(f, ")")
}
Expand Down Expand Up @@ -559,7 +526,7 @@ impl LeanFmt for Expr {
impl LeanFmt for FunDef {
fn lean_fmt(&self, f: &mut fmt::Formatter, cx: &LeanCtxt) -> fmt::Result {
let FunDef { name, sort, comment: _, body } = self;
write!(f, "def ")?;
write!(f, "noncomputable def ")?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary as soon as you open Classical, or do you need it for some specific functions?

name.lean_fmt(f, cx)?;
if let Some(body) = body {
for (arg, arg_sort) in iter::zip(&body.args, &sort.inputs) {
Expand Down Expand Up @@ -675,8 +642,6 @@ impl<'a> LeanFmt for LeanKConstraint<'a> {
if !cx.kvar_solutions.is_empty() {
writeln!(f, "namespace {namespace}\n")?;

let cx = cx.with_bool_mode(BoolMode::Prop);

if !cx.kvar_solutions.cut_solutions.is_empty() {
writeln!(f, "-- cyclic (cut) kvars")?;
for kvar_solution in &cx.kvar_solutions.cut_solutions {
Expand Down
Loading