diff --git a/Aesop/Forward/State.lean b/Aesop/Forward/State.lean index 1961fa41a..65f3f1d20 100644 --- a/Aesop/Forward/State.lean +++ b/Aesop/Forward/State.lean @@ -193,6 +193,17 @@ def erasePatSubst (imap : InstMap) (subst : Substitution) (slot : SlotIndex) : let hs := hs.erase { fvarId? := none, subst } (ms, hs) +private def pHashSetSize [BEq α] [Hashable α] (s : PHashSet α) : Nat := + s.fold (init := 0) fun n _ => n + 1 + +/-- Extract stats from an `InstMap`. -/ +def stats (imap : InstMap) : Array ForwardInstantiationStats := Id.run do + let mut stats := #[] + for (_, m) in imap.map do + for (_, (ms, hs)) in m do + stats := stats.push { «matches» := pHashSetSize ms, hyps := pHashSetSize hs } + return stats + end InstMap set_option linter.missingDocs false in @@ -310,6 +321,10 @@ where else panic! s!"substitution contains no instantiation for variable {var}" +/-- Extract stats from a `VariableMap`. -/ +def stats (vmap : VariableMap) : Array ForwardInstantiationStats := + vmap.map.foldl (init := #[]) fun acc _ imap => acc ++ imap.stats + end VariableMap /-- Structure representing the state of a slot cluster. -/ @@ -599,6 +614,11 @@ def erasePatSubst (subst : Substitution) (pi : PremiseIndex) (cs : ClusterState) completeMatches := filterPHashSet (! ·.containsPatSubst subst) cs.completeMatches } +/-- Extract stats from a `ClusterState`. -/ +def stats (cs : ClusterState) : ForwardClusterStateStats where + slots := cs.slots.size + instantiationStats := cs.variableMap.stats + end ClusterState /-- The source of a pattern substitution. The same substitution can have @@ -739,6 +759,11 @@ def eraseHyp (h : FVarId) (pi : PremiseIndex) (rs : RuleState) : RuleState := let clusterStates := rs.clusterStates.map (·.eraseHyp h pi) { rs with clusterStates } +/-- Extract stats from a `RuleState`. -/ +def stats (rs : RuleState) : ForwardRuleStateStats where + ruleName := rs.rule.name + clusterStateStats := rs.clusterStates.map (·.stats) + end RuleState /-- State representing the non-complete matches of a given set of forward rules @@ -914,4 +939,9 @@ def updateTargetPatSubsts (goal : MVarId) (fs : ForwardState) : BaseM (ForwardState × Array ForwardRuleMatch) := fs.updateTargetPatSubstsCore #[] goal newPatSubsts +/-- Extract stats from a `ForwardState`. -/ +def stats (fs : ForwardState) : ForwardStateStats where + ruleStateStats := fs.ruleStates.foldl (init := #[]) fun acc _ rs => + acc.push rs.stats + end Aesop.ForwardState diff --git a/Aesop/RPINF.lean b/Aesop/RPINF.lean index 5bd4bbc66..0af93698e 100644 --- a/Aesop/RPINF.lean +++ b/Aesop/RPINF.lean @@ -1,3 +1,8 @@ +/- +Copyright (c) 2025 Jannis Limperg. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Jannis Limperg +-/ module public import Aesop.Util.Basic @@ -64,6 +69,7 @@ where def rpinf (e : Expr) : BaseM RPINF := withConstAesopTraceNode .rpinf (return m!"rpinf") do + profiling (fun stats _ elapsed => { stats with rpinf := stats.rpinf + elapsed } ) do aesop_trace[rpinf] "input:{indentExpr e}" let e ← rpinfRaw e let hash := pinfHash e.toExpr diff --git a/Aesop/Search/Main.lean b/Aesop/Search/Main.lean index f07cf12bd..685171bc7 100644 --- a/Aesop/Search/Main.lean +++ b/Aesop/Search/Main.lean @@ -14,6 +14,7 @@ public import Aesop.Search.Expansion public import Aesop.Search.ExpandSafePrefix public import Aesop.Search.Queue public import Aesop.Tree +public import Aesop.Tree.Stats public import Aesop.Frontend.Extension public section @@ -277,10 +278,12 @@ def search (goal : MVarId) (ruleSet? : Option LocalRuleSet := none) mkLocalRuleSet rss options | some ruleSet => pure ruleSet let ⟨Q, _⟩ := options.queue - let go : SearchM _ _ := do - show SearchM Q _ from - try searchLoop - finally freeTree + let go : SearchM Q _ := do + try + searchLoop + finally + collectGoalStatsIfEnabled + freeTree let ((goals, _, _), stats) ← go.run ruleSet options simpConfig simpConfigSyntax? goal |>.run stats return (goals, stats) diff --git a/Aesop/Stats/Basic.lean b/Aesop/Stats/Basic.lean index 49d1ba59a..e4e38dab3 100644 --- a/Aesop/Stats/Basic.lean +++ b/Aesop/Stats/Basic.lean @@ -16,7 +16,41 @@ open Lean namespace Aesop --- All times are in nanoseconds. +structure ForwardInstantiationStats where + «matches» : Nat + hyps : Nat + deriving Inhabited, ToJson + +structure ForwardClusterStateStats where + slots : Nat + instantiationStats : Array ForwardInstantiationStats + deriving Inhabited, ToJson + +structure ForwardRuleStateStats where + ruleName : RuleName + clusterStateStats : Array ForwardClusterStateStats + deriving Inhabited, ToJson + +structure ForwardStateStats where + ruleStateStats : Array ForwardRuleStateStats + deriving Inhabited, ToJson + +inductive GoalKind + | preNorm + | postNorm + deriving Inhabited, ToJson + +structure GoalStats where + goalId : Nat -- We don't use GoalId to avoid an import cycle + goalKind : GoalKind + /-- Number of fvars in the local context, excluding implementation detail + fvars. -/ + lctxSize : Nat + /-- This goal's depth in the search tree. -/ + depth : Nat + forwardStateStats : ForwardStateStats + deriving Inhabited, ToJson + structure RuleStats where rule : DisplayRuleName elapsed : Nanos @@ -63,14 +97,16 @@ structure Stats where ruleSelection : Nanos script : Nanos forwardState : Nanos + rpinf : Nanos scriptGenerated : Option ScriptGenerated ruleStats : Array RuleStats + goalStats : Array GoalStats deriving Inhabited namespace Stats protected def empty : Stats := by - refine' { scriptGenerated := none, ruleStats := #[], .. } <;> exact 0 + refine' { scriptGenerated := none, ruleStats := #[], goalStats := #[], .. } <;> exact 0 instance : EmptyCollection Stats := ⟨Stats.empty⟩ @@ -144,7 +180,7 @@ def _root_.Aesop.sortRuleStatsTotals def trace (p : Stats) (opt : TraceOption) : CoreM Unit := do if ! (← opt.isEnabled) then return - let { total, configParsing, ruleSetConstruction, search, ruleSelection, script, forwardState, scriptGenerated, ruleStats } := p + let { total, configParsing, ruleSetConstruction, search, ruleSelection, script, forwardState, rpinf, scriptGenerated, ruleStats, goalStats := _goalStats } := p -- TODO print goal stats? let totalRuleApplications := ruleStats.foldl (init := 0) λ total rp => total + rp.elapsed aesop_trace![opt] "Total: {total.printAsMillis}" @@ -156,6 +192,7 @@ def trace (p : Stats) (opt : TraceOption) : CoreM Unit := do (return m!"Search: {search.printAsMillis}") do aesop_trace![opt] "Rule selection: {ruleSelection.printAsMillis}" aesop_trace![opt] "Forward state updates: {forwardState.printAsMillis}" + aesop_trace![opt] "RPINF: {rpinf.printAsMillis}" withConstAesopTraceNode opt (collapsed := false) (return m!"Rule applications: {totalRuleApplications.printAsMillis} [total / successful / failed]") do let timings := sortRuleStatsTotals p.ruleStatsTotals.toArray diff --git a/Aesop/Tree/Stats.lean b/Aesop/Tree/Stats.lean new file mode 100644 index 000000000..974c9fdd3 --- /dev/null +++ b/Aesop/Tree/Stats.lean @@ -0,0 +1,41 @@ +/- +Copyright (c) 2025 Jannis Limperg. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Jannis Limperg +-/ +module + +public import Aesop.Tree.Traversal +public import Aesop.Tree.TreeM + +public section + +namespace Aesop + +def Goal.stats (g : Goal) : TreeM GoalStats := do + let (goal, metaState) ← g.currentGoalAndMetaState (← getRootMetaState) + let decl := metaState.meta.mctx.getDecl goal + let lctxSize := decl.lctx.foldl (init := 0) fun acc ldecl => + if ldecl.isImplementationDetail then acc else acc + 1 + return { + goalId := g.id.toNat + goalKind := if g.isNormal then .postNorm else .preNorm + forwardStateStats := g.forwardState.stats + depth := g.depth + lctxSize + } + +def collectGoalStatsIfEnabled : TreeM Unit := do + if ← enableStats then + let go : StateRefT (Array GoalStats) TreeM Unit := + postTraverseDown + (fun gref => do + let stats ← (← gref.get).stats + modify (·.push stats)) + (fun _rref => return) + (fun _cref => return) + (.mvarCluster (← getThe Tree).root) + let goalStats ← (·.2) <$> go.run #[] + modifyStats ({ · with goalStats }) + +end Aesop