Skip to content

Commit fe046cc

Browse files
committed
feat: Derive Arbitrary instances for inductive datatypes.
1 parent 9f49266 commit fe046cc

17 files changed

+1542
-4
lines changed

Plausible.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ import Plausible.Testable
1010
import Plausible.Functions
1111
import Plausible.Attr
1212
import Plausible.Tactic
13+
import Plausible.Arbitrary
14+
import Plausible.DeriveArbitrary

Plausible/ArbitraryFueled.lean

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/-
2+
Copyright (c) 2025 AWS. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: AWS
5+
-/
6+
import Plausible.Arbitrary
7+
import Plausible.Gen
8+
9+
namespace Plausible
10+
11+
open Gen
12+
13+
/-- A typeclass for *fueled* random generation, i.e. a variant of
14+
the `Arbitrary` typeclass where the fuel for the generator is made explicit.
15+
- This typeclass is equivalent to Rocq QuickChick's `arbitrarySized` typeclass
16+
(QuickChick uses the `Nat` parameter as both fuel and the generator size,
17+
here we use it just for fuel, as Plausible's `Gen` type constructor
18+
already internalizes the size parameter.) -/
19+
class ArbitraryFueled (α : Type) where
20+
/-- Takes a `Nat` and produces a random generator dependent on the `Nat` parameter
21+
(which indicates the amount of fuel to be used before failing). -/
22+
arbitraryFueled : Nat → Gen α
23+
24+
/-- Every `ArbitraryFueled` instance gives rise to an `Arbitrary` instance -/
25+
instance [ArbitraryFueled α] : Arbitrary α where
26+
arbitrary := Gen.sized ArbitraryFueled.arbitraryFueled
27+
28+
/-- Raised when a fueled generator fails due to insufficient fuel. -/
29+
def Gen.outOfFuel : GenError :=
30+
.genError "out of fuel"
31+
32+
end Plausible

Plausible/Attr.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ initialize registerTraceClass `plausible.discarded
1313
initialize registerTraceClass `plausible.success
1414
initialize registerTraceClass `plausible.shrink.steps
1515
initialize registerTraceClass `plausible.shrink.candidates
16+
initialize registerTraceClass `plausible.deriving.arbitrary

Plausible/DeriveArbitrary.lean

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
/-
2+
Copyright (c) 2025 Ernest Ng. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Ernest Ng
5+
-/
6+
import Lean.Elab
7+
import Lean.Elab.Deriving.Basic
8+
import Lean.Elab.Deriving.Util
9+
10+
import Plausible.Arbitrary
11+
import Plausible.ArbitraryFueled
12+
13+
open Lean Elab Meta Parser Term
14+
open Elab.Deriving
15+
open Elab.Command
16+
17+
/-!
18+
19+
# Deriving Handler for `Arbitrary`
20+
21+
This file defines a handler which automatically derives `Arbitrary` instances
22+
for inductive types.
23+
24+
(Note that the deriving handler technically derives `ArbitraryFueled` instancces,
25+
but every `ArbitraryFueled` instance automatically results in an `Arbitrary` instance,
26+
as detailed in `Arbitrary.lean`.)
27+
28+
Note that the resulting `Arbitrary` and `ArbitraryFueled` instance should be considered
29+
to be opaque, following the convention for the deriving handler for Mathlib's `Encodable` typeclass.
30+
31+
Example usage:
32+
33+
```lean
34+
-- Datatype for binary trees
35+
inductive Tree
36+
| Leaf : Tree
37+
| Node : Nat → Tree → Tree → Tree
38+
deriving Arbitrary
39+
```
40+
41+
To sample from a derived generator, users can simply call `Arbitrary.runArbitrary`, specify the type
42+
for the desired generated values and provide some Nat to act as the generator's fuel parameter (10 in the example below):
43+
44+
```lean
45+
#eval Arbitrary.runArbitrary (α := Tree) 10
46+
```
47+
48+
To view the code for the derived generator, users can enable trace messages using the `plausible.deriving.arbitrary` trace class as follows:
49+
50+
```lean
51+
set_option trace.plausible.deriving.arbitrary true
52+
```
53+
54+
## Main definitions
55+
* Deriving handler for `ArbitraryFueled` typeclass
56+
57+
-/
58+
59+
namespace Plausible
60+
61+
open Arbitrary
62+
63+
/-- Takes the name of a constructor for an algebraic data type and returns an array
64+
containing `(argument_name, argument_type)` pairs.
65+
66+
If the algebraic data type is defined using anonymous constructor argument syntax, i.e.
67+
```
68+
inductive T where
69+
C1 : τ1 → … → τn
70+
71+
```
72+
Lean produces macro scopes when we try to access the names for the constructor args.
73+
In this case, we remove the macro scopes so that the name is user-accessible.
74+
(This will result in constructor argument names being non-unique in the array
75+
that is returned -- it is the caller's responsibility to produce fresh names.)
76+
-/
77+
def getCtorArgsNamesAndTypes (header : Header) (indVal : InductiveVal) (ctorName : Name) : MetaM (Array (Name × Expr)) := do
78+
let ctorInfo ← getConstInfoCtor ctorName
79+
80+
forallTelescopeReducing ctorInfo.type fun args _ => do
81+
let mut argNamesAndTypes := #[]
82+
83+
for i in *...args.size do
84+
let arg := args[i]!
85+
let localDecl ← arg.fvarId!.getDecl
86+
let argType := localDecl.type
87+
88+
let argName ← if i < indVal.numParams then pure header.argNames[i]! else Core.mkFreshUserName `a
89+
if i < indVal.numParams then
90+
continue
91+
else
92+
argNamesAndTypes := argNamesAndTypes.push (argName, argType)
93+
94+
return argNamesAndTypes
95+
96+
-- Note: the following functions closely follow the implementation of the deriving handler for `Repr` / `BEq`
97+
-- (see https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Deriving/Repr.lean).
98+
99+
open TSyntax.Compat in
100+
/-- Variant of `Deriving.Util.mkHeader` where we don't add an explicit binder
101+
of the form `($targetName : $targetType)` to the field `binders`
102+
(i.e. `binders` contains only implicit binders) -/
103+
def mkHeaderWithOnlyImplicitBinders (className : Name) (arity : Nat) (indVal : InductiveVal) : TermElabM Header := do
104+
let argNames ← mkInductArgNames indVal
105+
let binders ← mkImplicitBinders argNames
106+
let targetType ← mkInductiveApp indVal argNames
107+
let mut targetNames := #[]
108+
for _ in [:arity] do
109+
targetNames := targetNames.push (← mkFreshUserName `x)
110+
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
111+
return {
112+
binders := binders
113+
argNames := argNames
114+
targetNames := targetNames
115+
targetType := targetType
116+
}
117+
118+
open TSyntax.Compat in
119+
/-- Variant of `Deriving.Util.mkInstanceCmds` which is specialized to creating `ArbitraryFueled` instances
120+
that have `Arbitrary` inst-implicit binders.
121+
122+
Note that we can't use `mkInstanceCmds` out of the box,
123+
since it expects the inst-implicit binders and the instance we're creating to both belong to the same typeclass. -/
124+
def mkArbitraryFueledInstanceCmds (ctx : Deriving.Context) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Command) := do
125+
let mut instances := #[]
126+
for i in [:ctx.typeInfos.size] do
127+
let indVal := ctx.typeInfos[i]!
128+
if typeNames.contains indVal.name then
129+
let auxFunName := ctx.auxFunNames[i]!
130+
let argNames ← mkInductArgNames indVal
131+
let binders ← mkImplicitBinders argNames
132+
let binders := binders ++ (← mkInstImplicitBinders ``Arbitrary indVal argNames) -- this line is changed from
133+
let indType ← mkInductiveApp indVal argNames
134+
let type ← `($(mkCIdent ``ArbitraryFueled) $indType)
135+
let mut val := mkIdent auxFunName
136+
if useAnonCtor then
137+
val ← `(⟨$val⟩)
138+
let instCmd ← `(instance $binders:implicitBinder* : $type := $val)
139+
instances := instances.push instCmd
140+
return instances
141+
142+
/-- Creates a `Header` for the `Arbitrary` typeclass -/
143+
def mkArbitraryHeader (indVal : InductiveVal) : TermElabM Header :=
144+
mkHeaderWithOnlyImplicitBinders ``Arbitrary 1 indVal
145+
146+
/-- Creates the *body* of the generator that appears in the instance of the `ArbitraryFueled` typeclass -/
147+
def mkBody (header : Header) (inductiveVal : InductiveVal) (generatorType : TSyntax `term) : TermElabM Term := do
148+
-- Fetch the name of the target type (the type for which we are deriving a generator)
149+
let targetTypeName := inductiveVal.name
150+
151+
-- Produce `Ident`s for the `fuel` argument for the lambda
152+
-- at the end of the generator function, as well as the `aux_arb` inner helper function
153+
let freshFuel := Lean.mkIdent (← Core.mkFreshUserName `fuel)
154+
let freshFuel' := Lean.mkIdent (← Core.mkFreshUserName `fuel')
155+
let auxArb := mkIdent `aux_arb
156+
157+
-- Maintain two arrays which will be populated with pairs
158+
-- where the first component is a sub-generator (non-recursive / recursive)
159+
-- and the 2nd component is the generator's associated weight
160+
let mut weightedNonRecursiveGenerators := #[]
161+
let mut weightedRecursiveGenerators := #[]
162+
163+
-- We also need to keep track of non-recursive generators without their weights,
164+
-- since some of Plausible's `Gen` combinators operate on generator functions
165+
let mut nonRecursiveGeneratorsNoWeights := #[]
166+
167+
for ctorName in inductiveVal.ctors do
168+
let ctorIdent := mkIdent ctorName
169+
170+
let ctorArgNamesTypes ← getCtorArgsNamesAndTypes header inductiveVal ctorName
171+
let (ctorArgNames, ctorArgTypes) := Array.unzip ctorArgNamesTypes
172+
173+
/- Produce fresh names for each of the constructor's arguments.
174+
Producing fresh names is necessary in order to handle
175+
constructors expressed using the following syntax:
176+
```
177+
inductive Foo
178+
| C : T1 → ... → Tn
179+
```
180+
in which all the arguments to the constructor `C` don't have explicit names.
181+
-/
182+
let ctorArgIdents := Lean.mkIdent <$> ctorArgNames
183+
let ctorArgIdentsTypes := Array.zip ctorArgIdents ctorArgTypes
184+
185+
if ctorArgNamesTypes.isEmpty then
186+
-- Constructor is nullary, we can just use an generator of the form `pure ...` with weight 1,
187+
-- following the QuickChick convention.
188+
-- (For clarity, this generator is parenthesized in the code produced.)
189+
let pureGen ← `(($(Lean.mkIdent `pure) $ctorIdent))
190+
weightedNonRecursiveGenerators := weightedNonRecursiveGenerators.push (← `((1, $pureGen)))
191+
nonRecursiveGeneratorsNoWeights := nonRecursiveGeneratorsNoWeights.push pureGen
192+
else
193+
-- Add all the constructor's argument names + types to the local context,
194+
-- then produce the body of the sub-generator (& a flag indicating if the constructor is recursive)
195+
let (generatorBody, ctorIsRecursive) ←
196+
withLocalDeclsDND ctorArgNamesTypes (fun _ => do
197+
let mut doElems := #[]
198+
199+
-- Flag to indicate whether the constructor is recursive (initialized to `false`)
200+
let mut ctorIsRecursive := false
201+
202+
-- Examine each argument to see which of them require recursive calls to the generator
203+
for (freshIdent, argType) in ctorArgIdentsTypes do
204+
-- If the argument's type is the same as the target type,
205+
-- produce a recursive call to the generator using `aux_arb`,
206+
-- otherwise generate a value using `arbitrary`
207+
let bindExpr ←
208+
if argType.getAppFn.constName == targetTypeName then
209+
-- We've detected that the constructor has a recursive argument, so we update the flag
210+
ctorIsRecursive := true
211+
`(doElem| let $freshIdent ← $(mkIdent `aux_arb):term $(freshFuel'):term)
212+
else
213+
`(doElem| let $freshIdent ← $(mkIdent ``Arbitrary.arbitrary):term)
214+
doElems := doElems.push bindExpr
215+
216+
-- Create an expression `return C x1 ... xn` at the end of the generator, where
217+
-- `C` is the constructor name and the `xi` are the generated values for the args
218+
let pureExpr ← `(doElem| return $ctorIdent $ctorArgIdents*)
219+
doElems := doElems.push pureExpr
220+
221+
-- Put the body of the generator together in an explicitly-parenthesized `do`-block
222+
let generatorBody ← `((do $[$doElems:doElem]*))
223+
pure (generatorBody, ctorIsRecursive))
224+
225+
if !ctorIsRecursive then
226+
-- Non-recursive generators have weight 1, following the QuickChick convention
227+
weightedNonRecursiveGenerators := weightedNonRecursiveGenerators.push (← `((1, $generatorBody)))
228+
nonRecursiveGeneratorsNoWeights := nonRecursiveGeneratorsNoWeights.push generatorBody
229+
else
230+
-- Recursive generaotrs have an associated weight of `fuel' + 1`, following the QuickChick convention
231+
weightedRecursiveGenerators := weightedRecursiveGenerators.push (← ``(($freshFuel' + 1, $generatorBody)))
232+
233+
-- Use the first non-recursive generator (without its weight) as the default generator
234+
-- If the target type has no non-recursive constructors, we emit an error message
235+
-- saying that we cannot derive a generator for that type
236+
let defaultGenerator ← Option.getDM (nonRecursiveGeneratorsNoWeights[0]?)
237+
(throwError m!"derive Arbitrary failed, {targetTypeName} has no non-recursive constructors")
238+
239+
-- Create the cases for the pattern-match on the fuel argument
240+
-- If `fuel = 0`, pick one of the non-recursive generators
241+
let mut caseExprs := #[]
242+
let zeroCase ← `(Term.matchAltExpr| | $(mkIdent ``Nat.zero) => $(mkIdent ``Gen.oneOfWithDefault) $defaultGenerator [$nonRecursiveGeneratorsNoWeights,*])
243+
caseExprs := caseExprs.push zeroCase
244+
245+
-- If `fuel = fuel' + 1`, pick a generator (it can be non-recursive or recursive)
246+
let mut allWeightedGenerators ← `([$weightedNonRecursiveGenerators,*, $weightedRecursiveGenerators,*])
247+
let succCase ← `(Term.matchAltExpr| | $freshFuel' + 1 => $(mkIdent ``Gen.frequency) $defaultGenerator $allWeightedGenerators)
248+
caseExprs := caseExprs.push succCase
249+
250+
-- Create function argument for the generator fuel
251+
let fuelParam ← `(Term.letIdBinder| ($freshFuel : $(mkIdent `Nat)))
252+
let matchExpr ← `(match $freshFuel:ident with $caseExprs:matchAlt*)
253+
254+
-- Create an instance of the `ArbitraryFueled` typeclass
255+
`(let rec $auxArb:ident $fuelParam : $generatorType :=
256+
$matchExpr
257+
fun $freshFuel => $auxArb $freshFuel)
258+
259+
260+
/-- Creates the function definition for the derived generator -/
261+
def mkAuxFunction (ctx : Deriving.Context) (i : Nat) : TermElabM Command := do
262+
let auxFunName := ctx.auxFunNames[i]!
263+
let indVal := ctx.typeInfos[i]!
264+
let header ← mkArbitraryHeader indVal
265+
let mut binders := header.binders
266+
267+
-- Determine the type of the generator
268+
-- (the `Plausible.Gen` type constructor applied to the name of the `inductive` type, plus any type parameters)
269+
let targetType ← mkInductiveApp ctx.typeInfos[i]! header.argNames
270+
let generatorType ← `($(mkIdent ``Plausible.Gen) $targetType)
271+
272+
-- Create the body of the generator function
273+
let mut body ← mkBody header indVal generatorType
274+
275+
-- For mutually-recursive types, we need to create
276+
-- local `let`-definitions containing the relevant `ArbitraryFueled` instances so that
277+
-- the derived generator typechecks
278+
if ctx.usePartial then
279+
let letDecls ← mkLocalInstanceLetDecls ctx ``ArbitraryFueled header.argNames
280+
body ← mkLet letDecls body
281+
282+
-- If we are deriving a generator for a bunch of mutually-recursive types,
283+
-- the derived generator needs to be marked `partial` (following the implementation
284+
-- of the `deriving Repr` handler)
285+
if ctx.usePartial then
286+
`(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $(mkIdent ``Nat) → $generatorType := $body:term)
287+
else
288+
`(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $(mkIdent ``Nat) → $generatorType := $body:term)
289+
290+
/-- Creates a `mutual ... end` block containing the definitions of the derived generators -/
291+
def mkMutualBlock (ctx : Deriving.Context) : TermElabM Syntax := do
292+
let mut auxDefs := #[]
293+
for i in *...ctx.typeInfos.size do
294+
auxDefs := auxDefs.push (← mkAuxFunction ctx i)
295+
`(mutual
296+
$auxDefs:command*
297+
end)
298+
299+
/-- Creates an instance of the `ArbitraryFueled` typeclass -/
300+
private def mkArbitraryFueledInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do
301+
let ctx ← mkContext ``Arbitrary "arbitrary" declName
302+
let cmds := #[← mkMutualBlock ctx] ++ (← mkArbitraryFueledInstanceCmds ctx #[declName])
303+
trace[plausible.deriving.arbitrary] "\n{cmds}"
304+
return cmds
305+
306+
/-- Deriving handler which produces an instance of the `ArbitraryFueled` typeclass for
307+
each type specified in `declNames` -/
308+
def mkArbitraryInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
309+
if (← declNames.allM isInductive) then
310+
for declName in declNames do
311+
let cmds ← liftTermElabM $ mkArbitraryFueledInstanceCmd declName
312+
cmds.forM elabCommand
313+
return true
314+
else
315+
throwError "Cannot derive instance of Arbitrary typeclass for non-inductive types"
316+
return false
317+
318+
initialize
319+
registerDerivingHandler ``Arbitrary mkArbitraryInstanceHandler
320+
321+
end Plausible

0 commit comments

Comments
 (0)