Skip to content

Commit 6124755

Browse files
committed
array index fold/unfold
1 parent 2947727 commit 6124755

File tree

10 files changed

+261
-50
lines changed

10 files changed

+261
-50
lines changed

prusti-encoder/src/encoders/mir_impure.rs

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,8 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> {
508508

509509
fn pcg_repack(&mut self, repack_op: &RepackOp<'vir>) {
510510
match repack_op {
511-
RepackOp::Expand(place, _target, capability_kind)
512-
| RepackOp::Collapse(place, _target, capability_kind) => {
511+
RepackOp::Expand(place, target, capability_kind)
512+
| RepackOp::Collapse(place, target, capability_kind) => {
513513
if matches!(capability_kind, CapabilityKind::Write) {
514514
// Collapsing an already exhaled place is a no-op
515515
// TODO: unless it's through a Ref I imagine?
@@ -557,15 +557,61 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> {
557557
.unwrap()
558558
.apply_cast_if_necessary(self.vcx, proj_app);
559559
}*/
560-
self.stmt(self.vcx.mk_unfold_stmt(predicate));
560+
match target.last_projection() {
561+
Some((_, mir::PlaceElem::Index(index_local))) => {
562+
let usize_ty_out = self
563+
.deps
564+
.require_local::<RustTySnapshotsEnc>(self.vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize)))
565+
.unwrap();
566+
let index_args = self.vcx.alloc_slice(&[usize_ty_out.generic_snapshot.specifics.expect_primitive().snap_to_prim.apply(self.vcx, [self.encode_operand_snap(&mir::Operand::Copy(index_local.into()))])]
567+
.into_iter()
568+
.chain(args.into_iter().copied())
569+
.collect::<Vec<_>>());
570+
let unfold_index = place_ty_out
571+
.generic_predicate
572+
.expect_array()
573+
.unfold_index;
574+
self.stmt(
575+
self.vcx.alloc(vir::StmtGenData::new(
576+
self.vcx.alloc(unfold_index.apply(self.vcx, index_args)),
577+
)),
578+
);
579+
},
580+
_ => {
581+
self.stmt(self.vcx.mk_unfold_stmt(predicate));
582+
}
583+
}
561584
for (apply, _) in &casts {
562585
self.stmt(apply);
563586
}
564587
} else {
565588
for (_, undo) in &casts {
566589
self.stmt(undo);
567590
}
568-
self.stmt(self.vcx.mk_fold_stmt(predicate));
591+
match target.last_projection() {
592+
Some((_, mir::PlaceElem::Index(index_local))) => {
593+
let usize_ty_out = self
594+
.deps
595+
.require_local::<RustTySnapshotsEnc>(self.vcx.tcx().mk_ty_from_kind(ty::TyKind::Uint(ty::UintTy::Usize)))
596+
.unwrap();
597+
let index_args = self.vcx.alloc_slice(&[usize_ty_out.generic_snapshot.specifics.expect_primitive().snap_to_prim.apply(self.vcx, [self.encode_operand_snap(&mir::Operand::Copy(index_local.into()))])]
598+
.into_iter()
599+
.chain(args.into_iter().copied())
600+
.collect::<Vec<_>>());
601+
let fold_index = place_ty_out
602+
.generic_predicate
603+
.expect_array()
604+
.fold_index;
605+
self.stmt(
606+
self.vcx.alloc(vir::StmtGenData::new(
607+
self.vcx.alloc(fold_index.apply(self.vcx, index_args)),
608+
)),
609+
);
610+
},
611+
_ => {
612+
self.stmt(self.vcx.mk_fold_stmt(predicate));
613+
}
614+
}
569615
}
570616
}
571617
RepackOp::Weaken(place, CapabilityKind::Exclusive, CapabilityKind::Write) => {
@@ -741,6 +787,21 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> {
741787
place: &EncodePlaceResult<'vir>,
742788
) -> Vec<(vir::Stmt<'vir>, vir::Stmt<'vir>)> {
743789
match place.ty.ty.kind() {
790+
TyKind::Array(elem_ty, _) => {
791+
// TODO: make place_casts take the actual place as argument, so
792+
// that we don't have to fake this:
793+
let elem = mir::ProjectionElem::Index(0usize.into());
794+
let proj_app = self.encode_place_element(place.ty, elem, place.expr);
795+
self.deps
796+
.require_local::<RustTyCastersEnc<CastTypeImpure>>(
797+
*elem_ty,
798+
)
799+
.unwrap()
800+
.cast_to_concrete_if_possible(self.vcx, proj_app)
801+
.into_iter()
802+
.map(|cs| (cs.apply_cast_stmt, cs.unapply_cast_stmt))
803+
.collect()
804+
}
744805
TyKind::Adt(def, _) if def.is_box() => {
745806
let proj_app =
746807
self.encode_place_element(place.ty, mir::ProjectionElem::Deref, place.expr);
@@ -853,6 +914,24 @@ impl<'vir, 'enc, E: TaskEncoder> ImpureEncVisitor<'vir, 'enc, E> {
853914
.ref_to_args(self.vcx, instantiated_ty, expr);
854915
projection_p.apply(self.vcx, proj_args)
855916
}
917+
mir::ProjectionElem::Index(v) => {
918+
let e_ty = self
919+
.deps
920+
.require_ref::<RustTyPredicatesEnc>(place_ty.ty)
921+
.unwrap();
922+
let index_access = e_ty
923+
.generic_predicate
924+
.expect_array()
925+
.index_access;
926+
let instantiated_ty = self
927+
.deps
928+
.require_local::<LiftedTyEnc<EncodeGenericsAsLifted>>(place_ty.ty)
929+
.unwrap();
930+
let proj_args = e_ty
931+
.generic_predicate
932+
.ref_to_args(self.vcx, instantiated_ty, expr);
933+
index_access.apply(self.vcx, proj_args)
934+
}
856935
// TODO: should all variants start at the same `Ref`?
857936
mir::ProjectionElem::Downcast(..) => expr,
858937
mir::ProjectionElem::Deref => {
@@ -1157,7 +1236,6 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor<
11571236
box kind @ mir::AggregateKind::Array(elem_ty),
11581237
values,
11591238
) => {
1160-
//let e_elem_ty = self.deps.require_ref::<RustTySnapshotsEnc>(*elem_ty).unwrap();
11611239
let generic_enc = self.deps.require_ref::<GenericEnc>(()).unwrap();
11621240
let e_rvalue_ty = self.deps.require_ref::<RustTyPredicatesEnc>(rvalue_ty).unwrap();
11631241
let prim = e_rvalue_ty.generic_predicate.expect_array();
@@ -1169,7 +1247,7 @@ impl<'vir, 'enc, E: TaskEncoder> mir::visit::Visitor<'vir> for ImpureEncVisitor<
11691247
).unwrap();
11701248
let value_snaps = values.iter().map(|value| self.encode_operand_snap(value)).collect::<Vec<_>>();
11711249
let casted_values = ty_caster.apply_casts(self.vcx, value_snaps.into_iter());
1172-
prim.prim_to_snap.apply(self.vcx, [
1250+
prim.snap_data.prim_to_snap.apply(self.vcx, [
11731251
self.vcx.mk_seq_lit(self.vcx.alloc_slice(&casted_values), &generic_enc.param_snapshot),
11741252
])
11751253
}

prusti-encoder/src/encoders/type/kinds/array.rs

Lines changed: 122 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use crate::encoders::{
2-
domain::{DomainBuilder, DomainEnc, DomainEncOutputRef, DomainEncSpecifics}, predicate::{PredicateBuilder, PredicateEncData}, rust_ty_snapshots::RustTySnapshotsEnc, snapshot::SnapshotEncOutput, PredicateEnc, PredicateEncOutputRef
2+
domain::{DomainBuilder, DomainEnc, DomainEncOutputRef, DomainEncSpecifics}, predicate::{PredicateBuilder, PredicateEncData}, rust_ty_snapshots::RustTySnapshotsEnc, snapshot::SnapshotEncOutput, GenericEnc, PredicateEnc, PredicateEncOutputRef
33
};
44
use prusti_rustc_interface::middle::ty;
55
use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies};
6-
use vir::{FunctionIdent, ToKnownArity, UnaryArity};
6+
use vir::{FunctionIdent, MethodIdent, ToKnownArity, UnaryArity, UnknownArity};
77

88
#[derive(Clone, Copy, Debug)]
99
pub struct DomainDataArray<'vir> {
@@ -24,21 +24,27 @@ impl<'vir> DomainEncSpecifics<'vir> {
2424
}
2525
}
2626

27-
// TODO: PredicateEncDataArray
27+
#[derive(Clone, Copy, Debug)]
28+
pub struct PredicateEncDataArray<'vir> {
29+
pub snap_data: DomainDataArray<'vir>,
30+
pub index_access: FunctionIdent<'vir, UnknownArity<'vir>>,
31+
pub unfold_index: MethodIdent<'vir, UnknownArity<'vir>>,
32+
pub fold_index: MethodIdent<'vir, UnknownArity<'vir>>,
33+
}
2834

2935
impl<'vir> PredicateEncOutputRef<'vir> {
3036
#[track_caller]
31-
pub fn expect_array(&self) -> DomainDataArray<'vir> {
32-
match self.specifics {
33-
PredicateEncData::Array(prim) => prim,
37+
pub fn expect_array(&self) -> &PredicateEncDataArray<'vir> {
38+
match &self.specifics {
39+
PredicateEncData::Array(data) => data,
3440
s => panic!("expected array predicate data (got {s:?})"),
3541
}
3642
}
3743
}
3844

3945
pub(crate) fn domain<'vir>(
4046
task_key: <DomainEnc as TaskEncoder>::TaskKey<'vir>,
41-
output_ref: &DomainEncOutputRef<'vir>,
47+
_output_ref: &DomainEncOutputRef<'vir>,
4248
deps: &mut TaskEncoderDependencies<'vir, DomainEnc>,
4349
builder: &mut DomainBuilder<'vir>,
4450
) -> Result<DomainEncSpecifics<'vir>, EncodeFullError<'vir, DomainEnc>> {
@@ -68,7 +74,7 @@ pub(crate) fn domain<'vir>(
6874
pub(crate) fn predicate<'vir>(
6975
_task_key: <PredicateEnc as TaskEncoder>::TaskKey<'vir>,
7076
snap: SnapshotEncOutput<'vir>,
71-
_deps: &mut TaskEncoderDependencies<'vir, PredicateEnc>,
77+
deps: &mut TaskEncoderDependencies<'vir, PredicateEnc>,
7278
generic_decls: &[vir::LocalDecl<'vir>],
7379
generic_exprs: &[vir::Expr<'vir>],
7480
builder: &mut PredicateBuilder<'vir>,
@@ -83,45 +89,131 @@ pub(crate) fn predicate<'vir>(
8389
// let ty_kind = ty.kind();
8490

8591
let snap_type = snap.snapshot;
92+
let snap_data = snap.specifics.expect_array();
8693

8794
let ref_self = builder.vcx.mk_local("self", &vir::TypeData::Ref);
8895
let ref_self_decl = builder.vcx.mk_local_decl_local(ref_self);
8996

90-
// fields
91-
// let prim_field = builder.field("val", snap_type);
92-
9397
// main predicate
9498
let self_pred = builder.predicate(
9599
"",
96100
&[ref_self_decl]
97101
.into_iter()
98102
.chain(generic_decls.iter().cloned())
99103
.collect::<Vec<_>>(),
100-
None, // Some(vir::expr! { acc_field([prim_field](ref_self)) }),
104+
None,
101105
);
102106

103107
// Ref-to-snap
104-
builder.function_snap = Some(
105-
builder
106-
.mk_function(
107-
"snap",
108-
&[ref_self_decl]
109-
.into_iter()
110-
.chain(generic_decls.iter().cloned())
111-
.collect::<Vec<_>>(),
112-
snap_type,
113-
&[vir::expr! { acc_wildcard([self_pred](ref_self, ..[generic_exprs])) }],
114-
&[],
115-
None,
116-
// Some(vir::expr! {
117-
// unfolding_wildcard ([self_pred](ref_self)) in ([prim_field](ref_self))
118-
// }),
119-
)
120-
.1,
108+
let (snap_ident, snap_func) = builder.mk_function(
109+
"snap",
110+
&[ref_self_decl]
111+
.into_iter()
112+
.chain(generic_decls.iter().cloned())
113+
.collect::<Vec<_>>(),
114+
snap_type,
115+
&[vir::expr! { acc_wildcard([self_pred](ref_self, ..[generic_exprs])) }],
116+
&[],
117+
None,
118+
);
119+
builder.function_snap = Some(snap_func);
120+
121+
// "borrowed" predicate, to frame across index accesses
122+
let borrowed_pred = builder.predicate(
123+
"borrowed",
124+
&[ref_self_decl]
125+
.into_iter()
126+
.chain(generic_decls.iter().cloned())
127+
.collect::<Vec<_>>(),
128+
None,
129+
);
130+
let borrowed_snap = builder.function(
131+
"borrowed_snap",
132+
&[ref_self_decl]
133+
.into_iter()
134+
.chain(generic_decls.iter().cloned())
135+
.collect::<Vec<_>>(),
136+
snap_type,
137+
&[vir::expr! { acc_wildcard([borrowed_pred](ref_self, ..[generic_exprs])) }],
138+
&[],
139+
None,
140+
);
141+
142+
let index_access = builder.function(
143+
"index",
144+
&[ref_self_decl].into_iter()
145+
.chain(generic_decls.iter().cloned())
146+
.collect::<Vec<_>>(),
147+
&vir::TypeData::Ref,
148+
&[], // TODO: should have a read permission here!
149+
&[],
150+
None,
151+
);
152+
153+
// unfold/fold index
154+
let self_snap = vir::expr! { [snap_ident](ref_self, ..[generic_exprs]) };
155+
let self_val = vir::expr! { [snap_data.snap_to_prim](self_snap) };
156+
let index = builder.vcx.mk_local("index", &vir::TypeData::Int);
157+
let index_decl = builder.vcx.mk_local_decl_local(index);
158+
let generic_enc = deps.require_ref::<GenericEnc>(())?;
159+
let index_val = builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, self_val, vir::expr! { index });
160+
161+
let unfold_index = builder.method(
162+
"unfold_index",
163+
&[index_decl, ref_self_decl]
164+
.into_iter()
165+
.chain(generic_decls.iter().cloned())
166+
.collect::<Vec<_>>(),
167+
&[],
168+
&[
169+
vir::expr! { acc([self_pred](ref_self, ..[generic_exprs])) },
170+
vir::expr! { ((0) <= (index)) && ((index) < (vpr_seq_len(self_val))) },
171+
],
172+
&[
173+
vir::expr! { acc([borrowed_pred](ref_self, ..[generic_exprs])) },
174+
vir::expr! { acc([generic_enc.ref_to_pred]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) },
175+
vir::expr! { ([borrowed_snap](ref_self, ..[generic_exprs])) == (old(self_snap)) },
176+
vir::expr! { ([generic_enc.ref_to_snap]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) == (old(index_val)) },
177+
],
178+
);
179+
180+
let fold_index = builder.method(
181+
"fold_index",
182+
&[index_decl, ref_self_decl]
183+
.into_iter()
184+
.chain(generic_decls.iter().cloned())
185+
.collect::<Vec<_>>(),
186+
&[],
187+
&[
188+
vir::expr! { acc([borrowed_pred](ref_self, ..[generic_exprs])) },
189+
vir::expr! { acc([generic_enc.ref_to_pred]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) },
190+
vir::expr! { ((0) <= (index)) && ((index) < (vpr_seq_len([snap_data.snap_to_prim]([borrowed_snap](ref_self, ..[generic_exprs]))))) },
191+
],
192+
&[
193+
vir::expr! { acc([self_pred](ref_self, ..[generic_exprs])) },
194+
vir::expr! { (vpr_seq_len(self_val)) == (old(vpr_seq_len([snap_data.snap_to_prim]([borrowed_snap](ref_self, ..[generic_exprs]))))) },
195+
vir::expr! {
196+
forall i: [&vir::TypeData::Int] :: {[builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, self_val, vir::expr! { i })]} (((0) <= (i)) && ((i) < (vpr_seq_len(self_val))))
197+
==> (([builder.vcx.mk_bin_op_expr(vir::BinOpKind::SeqIndex, self_val, vir::expr! { i })]) == ([builder.vcx.mk_ternary_expr(
198+
vir::expr! { (i) == (index) },
199+
vir::expr! { old([generic_enc.ref_to_snap]([index_access](ref_self, ..[generic_exprs]), ..[generic_exprs])) },
200+
vir::expr! { old([builder.vcx.mk_bin_op_expr(
201+
vir::BinOpKind::SeqIndex,
202+
vir::expr! { [snap_data.snap_to_prim]([borrowed_snap](ref_self, ..[generic_exprs])) },
203+
vir::expr! { i },
204+
)]) },
205+
)]))
206+
},
207+
],
121208
);
122209

123210
Ok((
124-
PredicateEncData::Array(snap.specifics.expect_array()),
211+
PredicateEncData::Array(PredicateEncDataArray {
212+
snap_data: snap.specifics.expect_array(),
213+
index_access,
214+
unfold_index,
215+
fold_index,
216+
}),
125217
None,
126218
))
127219
}

prusti-encoder/src/encoders/type/predicate.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ use vir::{
88
use crate::encoders::GenericEnc;
99

1010
use super::{
11-
domain::DomainDataArray, kinds::primitive::DomainDataPrim, lifted::{generic::LiftedGeneric, ty::LiftedTy}, most_generic_ty::{get_vir_base_name_kind, MostGenericTy}, snapshot::SnapshotEnc
11+
kinds::primitive::DomainDataPrim, lifted::{generic::LiftedGeneric, ty::LiftedTy}, most_generic_ty::{get_vir_base_name_kind, MostGenericTy}, snapshot::SnapshotEnc
1212
};
1313

1414
pub use super::kinds::{
1515
adt::PredicateEncDataEnum,
16+
array::PredicateEncDataArray,
1617
immref::PredicateEncDataImmRef,
1718
mutref::PredicateEncDataMutRef,
1819
structlike::PredicateEncDataStruct,
@@ -30,7 +31,7 @@ pub enum PredicateEncError {
3031
#[derive(Clone, Copy, Debug)]
3132
pub enum PredicateEncData<'vir> {
3233
Never,
33-
Array(DomainDataArray<'vir>),
34+
Array(PredicateEncDataArray<'vir>),
3435
Primitive(DomainDataPrim<'vir>),
3536
// structs, tuples
3637
StructLike(PredicateEncDataStruct<'vir>),
@@ -266,7 +267,7 @@ impl<'vir> PredicateBuilder<'vir> {
266267
unreachable_to_snap: self.unreachable_to_snap.unwrap().1,
267268
function_snap: self.function_snap.unwrap(),
268269
ref_to_field_refs: self.functions,
269-
method_assign: self.methods[0],
270+
methods: self.methods,
270271
}
271272
}
272273
}
@@ -279,7 +280,7 @@ pub struct PredicateEncOutput<'vir> {
279280
pub unreachable_to_snap: vir::Function<'vir>,
280281
pub function_snap: vir::Function<'vir>,
281282
pub ref_to_field_refs: Vec<vir::Function<'vir>>,
282-
pub method_assign: vir::Method<'vir>,
283+
pub methods: Vec<vir::Method<'vir>>,
283284
}
284285

285286
impl TaskEncoder for PredicateEnc {
@@ -345,7 +346,7 @@ impl TaskEncoder for PredicateEnc {
345346
unreachable_to_snap: dep.unreachable_to_snap,
346347
function_snap: dep.ref_to_snap,
347348
ref_to_field_refs: vec![],
348-
method_assign,
349+
methods: vec![method_assign],
349350
},
350351
(),
351352
))

0 commit comments

Comments
 (0)