Skip to content

Commit 39880cc

Browse files
committed
cleanup the PR
1 parent 9bbdced commit 39880cc

File tree

3 files changed

+34
-22
lines changed

3 files changed

+34
-22
lines changed

src/Lean/Meta/Basic.lean

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,18 @@ We should also investigate the impact on memory consumption.
362362
-/
363363
abbrev DefEqCache := PersistentHashMap DefEqCacheKey Bool
364364

365+
/--
366+
A `DefEqTransCache` is a `DefEqCache` that is only valid in the original `MetavarContext`.
367+
It stores of the `numAssignments` from that original `MetavarContext`.
368+
If the `numAssignments` in the `MetavarContext` has increased, we invalidate this cache.
369+
And when we revert the metavariable context in `checkpointDefEq`, if the `numAssignments`
370+
in the original `MetavarContext` is smaller than in the cache, we revert the cache to its original.
371+
-/
372+
structure DefEqTransCache where
373+
cache : DefEqCache := {}
374+
numAssignments : Nat := 0
375+
deriving Inhabited
376+
365377
/--
366378
Cache datastructures for type inference, type class resolution, whnf, and definitional equality.
367379
-/
@@ -370,7 +382,7 @@ structure Cache where
370382
funInfo : FunInfoCache := {}
371383
synthInstance : SynthInstanceCache := {}
372384
whnf : WhnfCache := {}
373-
defEqTrans : DefEqCache × Nat := ({}, 0) -- transient cache for terms containing mvars or using nonstandard configuration options, it is valid as long as the count matches `MetavarContext.numAssignments`.
385+
defEqTrans : DefEqTransCache := {} -- transient cache for terms containing mvars or using nonstandard configuration options, it is valid as long as the count matches `MetavarContext.numAssignments`.
374386
defEqPerm : DefEqCache := {} -- permanent cache for terms not containing mvars and using standard configuration options
375387
deriving Inhabited
376388

@@ -629,9 +641,9 @@ def resetCache : MetaM Unit :=
629641

630642
@[inline] def modifyDefEqTransientCache (numAssignments : Nat) (f : DefEqCache → DefEqCache) : MetaM Unit :=
631643
modifyCache fun c =>
632-
let (transCache, numAssignmentsOld) := c.defEqTrans
644+
let transCache, numAssignmentsOld := c.defEqTrans
633645
let transCache := if numAssignments == numAssignmentsOld then transCache else {}
634-
{ c with defEqTrans := (f transCache, numAssignments) }
646+
{ c with defEqTrans := f transCache, numAssignments }
635647

636648
@[inline] def modifyDefEqPermCache (f : DefEqCache → DefEqCache) : MetaM Unit :=
637649
modifyCache fun ⟨c1, c2, c3, c4, c5, defeqPerm⟩ => ⟨c1, c2, c3, c4, c5, f defeqPerm⟩
@@ -650,7 +662,7 @@ def mkInfoCacheKey (expr : Expr) (nargs? : Option Nat) : MetaM InfoCacheKey :=
650662
return { expr, nargs?, configKey := (← read).configKey }
651663

652664
@[inline] def resetDefEqTransientCache : MetaM Unit :=
653-
modify fun s => { s with cache.defEqTrans := ({}, s.mctx.numAssignments) }
665+
modify fun s => { s with cache.defEqTrans := {}, s.mctx.numAssignments }
654666

655667
@[inline] def resetDefEqPermCaches : MetaM Unit :=
656668
modifyDefEqPermCache fun _ => {}
@@ -2246,13 +2258,13 @@ partial def processPostponed (mayPostpone : Bool := true) (exceptionOnFailure :=
22462258
return true
22472259
else
22482260
-- The transient cache needs to be reverted if it assumes an assignments that is being reverted.
2249-
let invalidCache := s.meta.mctx.numAssignments < (← get).cache.defEqTrans.2
2250-
s.restore (transCache := invalidCache)
2261+
let isInvalidCache := s.meta.mctx.numAssignments != (← get).cache.defEqTrans.numAssignments
2262+
s.restore (transCache := isInvalidCache)
22512263
return false
22522264
else
22532265
-- The transient cache needs to be reverted if it assumes an assignments that is being reverted.
2254-
let invalidCache := s.meta.mctx.numAssignments < (← get).cache.defEqTrans.2
2255-
s.restore (transCache := invalidCache)
2266+
let isInvalidCache := s.meta.mctx.numAssignments != (← get).cache.defEqTrans.numAssignments
2267+
s.restore (transCache := isInvalidCache)
22562268
return false
22572269
catch ex =>
22582270
s.restore
@@ -2293,8 +2305,8 @@ def isExprDefEq (t s : Expr) : MetaM Bool :=
22932305
We have tried in the past to track when the result was independent of the `MetavarContext` state
22942306
but it was not effective. It is more important to cache aggressively inside of a single `isDefEq`
22952307
call because some of the heuristics create many similar subproblems.
2296-
See issue #1102 for an example that triggers an exponential blowup if we don't use this more
2297-
aggressive form of caching.
2308+
See issue #1102 and `tests/lean/run/defEqTransCache.lean` for examples that trigger an exponential blowup
2309+
if we don't use this more aggressive form of caching.
22982310
-/
22992311
resetDefEqTransientCache
23002312
checkpointDefEq (mayPostpone := true) <| Meta.isExprDefEqAux t s

src/Lean/Meta/ExprDefEq.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ private partial def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : Meta
312312
let info := finfo.paramInfo[i]!
313313
if info.isInstImplicit then
314314
unless (← withInferTypeConfig <| Meta.isExprDefEqAux a₁ a₂) do
315-
return false
315+
return false
316316
else
317317
unless (← Meta.isExprDefEqAux a₁ a₂) do
318318
return false
@@ -2070,7 +2070,7 @@ private def mkCacheKey (t s : Expr) : MetaM DefEqCacheKeyInfo := do
20702070
private def getCachedResult (keyInfo : DefEqCacheKeyInfo) : MetaM LBool := do
20712071
let cache ← match keyInfo.kind with
20722072
| .transient numAssignments =>
2073-
let (cache, numAssignmentsCache) := (← get).cache.defEqTrans
2073+
let cache, numAssignmentsCache := (← get).cache.defEqTrans
20742074
if numAssignments == numAssignmentsCache then
20752075
pure cache
20762076
else
@@ -2087,14 +2087,14 @@ private def cacheResult (keyInfo : DefEqCacheKeyInfo) (result : Bool) : MetaM Un
20872087
| .transient numAssignmentsOld =>
20882088
/-
20892089
If the result is `false`, we cache it at `numAssignmentsOld`.
2090-
If the result is `true`, we check that the number of assignments hasn't increased.
2090+
If the result is `true`, we only cache it if the number of assignments hasn't increase.
20912091
-/
20922092
if !result then
20932093
modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result
20942094
else
20952095
let numAssignmentsNew := (← getMCtx).numAssignments
20962096
if numAssignmentsOld == numAssignmentsNew then
2097-
modifyDefEqTransientCache numAssignmentsNew fun c => c.insert key result
2097+
modifyDefEqTransientCache numAssignmentsOld fun c => c.insert key result
20982098

20992099
private def whnfCoreAtDefEq (e : Expr) : MetaM Expr := do
21002100
if backward.isDefEq.lazyWhnfCore.get (← getOptions) then

tests/lean/run/defEqTransCache.lean

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ import Lean
22
/-!
33
Previously, unification wouldn't be very careful with the `isDefEq` cache for terms containing metavariables.
44
- This is mostly problematic because erasing the cache leads to exponential slowdowns (`test1` & `test2`)
5-
- but in some cases it leads to metavariable assignments leaking into places where they shouldn't be,
6-
which either causes unification to fail where it should succeed (`test3`)
7-
or to succeed where it is expected to fail.
8-
5+
- but in some cases it lead to metavariable assignments leaking into places where they shouldn't be,
6+
which either caused unification to fail where it should succeed (`test3`)
7+
or to succeed where it is expected to fail (which happened in one mathlib proof).
98
-/
9+
1010
set_option maxHeartbeats 1000
1111

1212
namespace test1
@@ -16,10 +16,10 @@ class A (n : Nat) where
1616
instance [A n] : A (n+1) where
1717
x := A.x n
1818

19-
theorem test [A 0] : A.x 100 = sorry := sorry
19+
theorem test [A 0] : A.x 100 = 0 := sorry
2020

21-
-- Previously, this example was exponentially slow
22-
example [A 1] : A.x 100 = sorry := by
21+
-- This rewrite should fail. Previously, it failed exponentially slowly
22+
example [A 1] : A.x 100 = 0 := by
2323
fail_if_success rw [@test]
2424
sorry
2525
end test1
@@ -66,7 +66,7 @@ elab "unfold_head" e:term : term => do
6666
let e ← Elab.Term.elabTerm e none
6767
unfoldDefinition e
6868

69-
-- we use `unfold_head` in order to get the raw kernel projection `·.1` instead of the projection funtcion `A.x`.
69+
-- use `unfold_head` to get the raw kernel projection `·.1` instead of the projection funtcion `A.x`
7070
def test {α} (i : B α) : unfold_head i.toA.x := sorry
7171

7272
-- Previously, in this example the unification failed,

0 commit comments

Comments
 (0)