Skip to content

Commit 5dcf39e

Browse files
committed
feat: Derive Arbitrary instances for inductive datatypes.
1 parent 9f49266 commit 5dcf39e

17 files changed

+1548
-9
lines changed

Plausible.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ import Plausible.Testable
1010
import Plausible.Functions
1111
import Plausible.Attr
1212
import Plausible.Tactic
13+
import Plausible.Arbitrary
14+
import Plausible.DeriveArbitrary

Plausible/ArbitraryFueled.lean

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/-
2+
Copyright (c) 2025 AWS. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: AWS
5+
-/
6+
import Plausible.Arbitrary
7+
import Plausible.Gen
8+
9+
namespace Plausible
10+
11+
open Gen
12+
13+
/-- A typeclass for *fueled* random generation, i.e. a variant of
14+
the `Arbitrary` typeclass where the fuel for the generator is made explicit.
15+
- This typeclass is equivalent to Rocq QuickChick's `arbitrarySized` typeclass
16+
(QuickChick uses the `Nat` parameter as both fuel and the generator size,
17+
here we use it just for fuel, as Plausible's `Gen` type constructor
18+
already internalizes the size parameter.) -/
19+
class ArbitraryFueled (α : Type) where
20+
/-- Takes a `Nat` and produces a random generator dependent on the `Nat` parameter
21+
(which indicates the amount of fuel to be used before failing). -/
22+
arbitraryFueled : Nat → Gen α
23+
24+
/-- Every `ArbitraryFueled` instance gives rise to an `Arbitrary` instance -/
25+
instance [ArbitraryFueled α] : Arbitrary α where
26+
arbitrary := Gen.sized ArbitraryFueled.arbitraryFueled
27+
28+
/-- Raised when a fueled generator fails due to insufficient fuel. -/
29+
def Gen.outOfFuel : GenError :=
30+
.genError "out of fuel"
31+
32+
end Plausible

Plausible/Attr.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ initialize registerTraceClass `plausible.discarded
1313
initialize registerTraceClass `plausible.success
1414
initialize registerTraceClass `plausible.shrink.steps
1515
initialize registerTraceClass `plausible.shrink.candidates
16+
initialize registerTraceClass `plausible.deriving.arbitrary

Plausible/DeriveArbitrary.lean

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
/-
2+
Copyright (c) 2025 Ernest Ng. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Ernest Ng
5+
-/
6+
import Lean.Elab
7+
import Lean.Elab.Deriving.Basic
8+
import Lean.Elab.Deriving.Util
9+
10+
import Plausible.Arbitrary
11+
import Plausible.ArbitraryFueled
12+
13+
open Lean Elab Meta Parser Term
14+
open Elab.Deriving
15+
open Elab.Command
16+
17+
/-!
18+
19+
# Deriving Handler for `Arbitrary`
20+
21+
This file defines a handler which automatically derives `Arbitrary` instances
22+
for inductive types.
23+
24+
(Note that the deriving handler technically derives `ArbitraryFueled` instancces,
25+
but every `ArbitraryFueled` instance automatically results in an `Arbitrary` instance,
26+
as detailed in `Arbitrary.lean`.)
27+
28+
Note that the resulting `Arbitrary` and `ArbitraryFueled` instance should be considered
29+
to be opaque, following the convention for the deriving handler for Mathlib's `Encodable` typeclass.
30+
31+
Example usage:
32+
33+
```lean
34+
-- Datatype for binary trees
35+
inductive Tree
36+
| Leaf : Tree
37+
| Node : Nat → Tree → Tree → Tree
38+
deriving Arbitrary
39+
```
40+
41+
To sample from a derived generator, users can simply call `Arbitrary.runArbitrary`, specify the type
42+
for the desired generated values and provide some Nat to act as the generator's fuel parameter (10 in the example below):
43+
44+
```lean
45+
#eval Arbitrary.runArbitrary (α := Tree) 10
46+
```
47+
48+
To view the code for the derived generator, users can enable trace messages using the `plausible.deriving.arbitrary` trace class as follows:
49+
50+
```lean
51+
set_option trace.plausible.deriving.arbitrary true
52+
```
53+
54+
## Main definitions
55+
* Deriving handler for `ArbitraryFueled` typeclass
56+
57+
-/
58+
59+
namespace Plausible
60+
61+
open Arbitrary
62+
63+
/-- Takes the name of a constructor for an algebraic data type and returns an array
64+
containing `(argument_name, argument_type)` pairs.
65+
66+
If the algebraic data type is defined using anonymous constructor argument syntax, i.e.
67+
```
68+
inductive T where
69+
C1 : τ1 → … → τn
70+
71+
```
72+
Lean produces macro scopes when we try to access the names for the constructor args.
73+
In this case, we remove the macro scopes so that the name is user-accessible.
74+
(This will result in constructor argument names being non-unique in the array
75+
that is returned -- it is the caller's responsibility to produce fresh names.)
76+
-/
77+
def getCtorArgsNamesAndTypes (_header : Header) (indVal : InductiveVal) (ctorName : Name) : MetaM (Array (Name × Expr)) := do
78+
let ctorInfo ← getConstInfoCtor ctorName
79+
80+
forallTelescopeReducing ctorInfo.type fun args _ => do
81+
let mut argNamesAndTypes := #[]
82+
83+
for i in *...args.size do
84+
let arg := args[i]!
85+
let argType ← arg.fvarId!.getType
86+
87+
if i < indVal.numParams then
88+
continue
89+
else
90+
let argName ← Core.mkFreshUserName `a
91+
argNamesAndTypes := argNamesAndTypes.push (argName, argType)
92+
return argNamesAndTypes
93+
94+
-- Note: the following functions closely follow the implementation of the deriving handler for `Repr` / `BEq`
95+
-- (see https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Deriving/Repr.lean).
96+
97+
open TSyntax.Compat in
98+
/-- Variant of `Deriving.Util.mkHeader` where we don't add an explicit binder
99+
of the form `($targetName : $targetType)` to the field `binders`
100+
(i.e. `binders` contains only implicit binders) -/
101+
def mkHeaderWithOnlyImplicitBinders (className : Name) (arity : Nat) (indVal : InductiveVal) : TermElabM Header := do
102+
let argNames ← mkInductArgNames indVal
103+
let binders ← mkImplicitBinders argNames
104+
let targetType ← mkInductiveApp indVal argNames
105+
let mut targetNames := #[]
106+
for _ in [:arity] do
107+
targetNames := targetNames.push (← mkFreshUserName `x)
108+
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
109+
return {
110+
binders := binders
111+
argNames := argNames
112+
targetNames := targetNames
113+
targetType := targetType
114+
}
115+
116+
open TSyntax.Compat in
117+
/-- Variant of `Deriving.Util.mkInstanceCmds` which is specialized to creating `ArbitraryFueled` instances
118+
that have `Arbitrary` inst-implicit binders.
119+
120+
Note that we can't use `mkInstanceCmds` out of the box,
121+
since it expects the inst-implicit binders and the instance we're creating to both belong to the same typeclass. -/
122+
def mkArbitraryFueledInstanceCmds (ctx : Deriving.Context) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Command) := do
123+
let mut instances := #[]
124+
for i in [:ctx.typeInfos.size] do
125+
let indVal := ctx.typeInfos[i]!
126+
if typeNames.contains indVal.name then
127+
let auxFunName := ctx.auxFunNames[i]!
128+
let argNames ← mkInductArgNames indVal
129+
let binders ← mkImplicitBinders argNames
130+
let binders := binders ++ (← mkInstImplicitBinders ``Arbitrary indVal argNames) -- this line is changed from
131+
let indType ← mkInductiveApp indVal argNames
132+
let type ← `($(mkCIdent ``ArbitraryFueled) $indType)
133+
let mut val := mkIdent auxFunName
134+
if useAnonCtor then
135+
val ← `(⟨$val⟩)
136+
let instCmd ← `(instance $binders:implicitBinder* : $type := $val)
137+
instances := instances.push instCmd
138+
return instances
139+
140+
/-- Creates a `Header` for the `Arbitrary` typeclass -/
141+
def mkArbitraryHeader (indVal : InductiveVal) : TermElabM Header :=
142+
mkHeaderWithOnlyImplicitBinders ``Arbitrary 1 indVal
143+
144+
/-- Creates the *body* of the generator that appears in the instance of the `ArbitraryFueled` typeclass -/
145+
def mkBody (header : Header) (inductiveVal : InductiveVal) (generatorType : TSyntax `term) : TermElabM Term := do
146+
-- Fetch the name of the target type (the type for which we are deriving a generator)
147+
let targetTypeName := inductiveVal.name
148+
149+
-- Produce `Ident`s for the `fuel` argument for the lambda
150+
-- at the end of the generator function, as well as the `aux_arb` inner helper function
151+
let freshFuel := Lean.mkIdent (← Core.mkFreshUserName `fuel)
152+
let freshFuel' := Lean.mkIdent (← Core.mkFreshUserName `fuel')
153+
let auxArb := mkIdent `aux_arb
154+
155+
-- Maintain two arrays which will be populated with pairs
156+
-- where the first component is a sub-generator (non-recursive / recursive)
157+
-- and the 2nd component is the generator's associated weight
158+
let mut weightedNonRecursiveGenerators := #[]
159+
let mut weightedRecursiveGenerators := #[]
160+
161+
-- We also need to keep track of non-recursive generators without their weights,
162+
-- since some of Plausible's `Gen` combinators operate on generator functions
163+
let mut nonRecursiveGeneratorsNoWeights := #[]
164+
165+
for ctorName in inductiveVal.ctors do
166+
let ctorIdent := mkIdent ctorName
167+
168+
let ctorArgNamesTypes ← getCtorArgsNamesAndTypes header inductiveVal ctorName
169+
let (ctorArgNames, ctorArgTypes) := Array.unzip ctorArgNamesTypes
170+
171+
/- Produce fresh names for each of the constructor's arguments.
172+
Producing fresh names is necessary in order to handle
173+
constructors expressed using the following syntax:
174+
```
175+
inductive Foo
176+
| C : T1 → ... → Tn
177+
```
178+
in which all the arguments to the constructor `C` don't have explicit names.
179+
-/
180+
let ctorArgIdents := Lean.mkIdent <$> ctorArgNames
181+
let ctorArgIdentsTypes := Array.zip ctorArgIdents ctorArgTypes
182+
183+
if ctorArgNamesTypes.isEmpty then
184+
-- Constructor is nullary, we can just use an generator of the form `pure ...` with weight 1,
185+
-- following the QuickChick convention.
186+
-- (For clarity, this generator is parenthesized in the code produced.)
187+
let pureGen ← `(($(Lean.mkIdent `pure) $ctorIdent))
188+
weightedNonRecursiveGenerators := weightedNonRecursiveGenerators.push (← `((1, $pureGen)))
189+
nonRecursiveGeneratorsNoWeights := nonRecursiveGeneratorsNoWeights.push pureGen
190+
else
191+
-- Add all the constructor's argument names + types to the local context,
192+
-- then produce the body of the sub-generator (& a flag indicating if the constructor is recursive)
193+
let (generatorBody, ctorIsRecursive) ←
194+
withLocalDeclsDND ctorArgNamesTypes (fun _ => do
195+
let mut doElems := #[]
196+
197+
-- Flag to indicate whether the constructor is recursive (initialized to `false`)
198+
let mut ctorIsRecursive := false
199+
200+
-- Examine each argument to see which of them require recursive calls to the generator
201+
for (freshIdent, argType) in ctorArgIdentsTypes do
202+
-- If the argument's type is the same as the target type,
203+
-- produce a recursive call to the generator using `aux_arb`,
204+
-- otherwise generate a value using `arbitrary`
205+
let bindExpr ←
206+
if argType.getAppFn.constName == targetTypeName then
207+
-- We've detected that the constructor has a recursive argument, so we update the flag
208+
ctorIsRecursive := true
209+
`(doElem| let $freshIdent ← $(mkIdent `aux_arb):term $(freshFuel'):term)
210+
else
211+
`(doElem| let $freshIdent ← $(mkIdent ``Arbitrary.arbitrary):term)
212+
doElems := doElems.push bindExpr
213+
214+
-- Create an expression `return C x1 ... xn` at the end of the generator, where
215+
-- `C` is the constructor name and the `xi` are the generated values for the args
216+
let pureExpr ← `(doElem| return $ctorIdent $ctorArgIdents*)
217+
doElems := doElems.push pureExpr
218+
219+
-- Put the body of the generator together in an explicitly-parenthesized `do`-block
220+
let generatorBody ← `((do $[$doElems:doElem]*))
221+
pure (generatorBody, ctorIsRecursive))
222+
223+
if !ctorIsRecursive then
224+
-- Non-recursive generators have weight 1, following the QuickChick convention
225+
weightedNonRecursiveGenerators := weightedNonRecursiveGenerators.push (← `((1, $generatorBody)))
226+
nonRecursiveGeneratorsNoWeights := nonRecursiveGeneratorsNoWeights.push generatorBody
227+
else
228+
-- Recursive generaotrs have an associated weight of `fuel' + 1`, following the QuickChick convention
229+
weightedRecursiveGenerators := weightedRecursiveGenerators.push (← ``(($freshFuel' + 1, $generatorBody)))
230+
231+
-- Use the first non-recursive generator (without its weight) as the default generator
232+
-- If the target type has no non-recursive constructors, we emit an error message
233+
-- saying that we cannot derive a generator for that type
234+
let defaultGenerator ← Option.getDM (nonRecursiveGeneratorsNoWeights[0]?)
235+
(throwError m!"derive Arbitrary failed, {targetTypeName} has no non-recursive constructors")
236+
237+
-- Create the cases for the pattern-match on the fuel argument
238+
-- If `fuel = 0`, pick one of the non-recursive generators
239+
let mut caseExprs := #[]
240+
let zeroCase ← `(Term.matchAltExpr| | $(mkIdent ``Nat.zero) => $(mkIdent ``Gen.oneOfWithDefault) $defaultGenerator [$nonRecursiveGeneratorsNoWeights,*])
241+
caseExprs := caseExprs.push zeroCase
242+
243+
-- If `fuel = fuel' + 1`, pick a generator (it can be non-recursive or recursive)
244+
let mut allWeightedGenerators ← `([$weightedNonRecursiveGenerators,*, $weightedRecursiveGenerators,*])
245+
let succCase ← `(Term.matchAltExpr| | $freshFuel' + 1 => $(mkIdent ``Gen.frequency) $defaultGenerator $allWeightedGenerators)
246+
caseExprs := caseExprs.push succCase
247+
248+
-- Create function argument for the generator fuel
249+
let fuelParam ← `(Term.letIdBinder| ($freshFuel : $(mkIdent `Nat)))
250+
let matchExpr ← `(match $freshFuel:ident with $caseExprs:matchAlt*)
251+
252+
-- Create an instance of the `ArbitraryFueled` typeclass
253+
`(let rec $auxArb:ident $fuelParam : $generatorType :=
254+
$matchExpr
255+
fun $freshFuel => $auxArb $freshFuel)
256+
257+
258+
/-- Creates the function definition for the derived generator -/
259+
def mkAuxFunction (ctx : Deriving.Context) (i : Nat) : TermElabM Command := do
260+
let auxFunName := ctx.auxFunNames[i]!
261+
let indVal := ctx.typeInfos[i]!
262+
let header ← mkArbitraryHeader indVal
263+
let mut binders := header.binders
264+
265+
-- Determine the type of the generator
266+
-- (the `Plausible.Gen` type constructor applied to the name of the `inductive` type, plus any type parameters)
267+
let targetType ← mkInductiveApp ctx.typeInfos[i]! header.argNames
268+
let generatorType ← `($(mkIdent ``Plausible.Gen) $targetType)
269+
270+
-- Create the body of the generator function
271+
let mut body ← mkBody header indVal generatorType
272+
273+
-- For mutually-recursive types, we need to create
274+
-- local `let`-definitions containing the relevant `ArbitraryFueled` instances so that
275+
-- the derived generator typechecks
276+
if ctx.usePartial then
277+
let letDecls ← mkLocalInstanceLetDecls ctx ``ArbitraryFueled header.argNames
278+
body ← mkLet letDecls body
279+
280+
-- If we are deriving a generator for a bunch of mutually-recursive types,
281+
-- the derived generator needs to be marked `partial` (following the implementation
282+
-- of the `deriving Repr` handler)
283+
if ctx.usePartial then
284+
`(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $(mkIdent ``Nat) → $generatorType := $body:term)
285+
else
286+
`(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $(mkIdent ``Nat) → $generatorType := $body:term)
287+
288+
/-- Creates a `mutual ... end` block containing the definitions of the derived generators -/
289+
def mkMutualBlock (ctx : Deriving.Context) : TermElabM Syntax := do
290+
let mut auxDefs := #[]
291+
for i in *...ctx.typeInfos.size do
292+
auxDefs := auxDefs.push (← mkAuxFunction ctx i)
293+
`(mutual
294+
$auxDefs:command*
295+
end)
296+
297+
/-- Creates an instance of the `ArbitraryFueled` typeclass -/
298+
private def mkArbitraryFueledInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do
299+
let ctx ← mkContext ``Arbitrary "arbitrary" declName
300+
let cmds := #[← mkMutualBlock ctx] ++ (← mkArbitraryFueledInstanceCmds ctx #[declName])
301+
trace[plausible.deriving.arbitrary] "\n{cmds}"
302+
return cmds
303+
304+
/-- Deriving handler which produces an instance of the `ArbitraryFueled` typeclass for
305+
each type specified in `declNames` -/
306+
def mkArbitraryInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
307+
if (← declNames.allM isInductive) then
308+
for declName in declNames do
309+
let cmds ← liftTermElabM $ mkArbitraryFueledInstanceCmd declName
310+
cmds.forM elabCommand
311+
return true
312+
else
313+
throwError "Cannot derive instance of Arbitrary typeclass for non-inductive types"
314+
return false
315+
316+
initialize
317+
registerDerivingHandler ``Arbitrary mkArbitraryInstanceHandler
318+
319+
end Plausible

0 commit comments

Comments
 (0)