|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Ernest Ng. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Ernest Ng |
| 5 | +-/ |
| 6 | +import Lean.Elab |
| 7 | +import Lean.Elab.Deriving.Basic |
| 8 | +import Lean.Elab.Deriving.Util |
| 9 | + |
| 10 | +import Plausible.Arbitrary |
| 11 | +import Plausible.ArbitraryFueled |
| 12 | + |
| 13 | +open Lean Elab Meta Parser Term |
| 14 | +open Elab.Deriving |
| 15 | +open Elab.Command |
| 16 | + |
| 17 | +/-! |
| 18 | +
|
| 19 | +# Deriving Handler for `Arbitrary` |
| 20 | +
|
| 21 | +This file defines a handler which automatically derives `Arbitrary` instances |
| 22 | +for inductive types. |
| 23 | +
|
| 24 | +(Note that the deriving handler technically derives `ArbitraryFueled` instancces, |
| 25 | +but every `ArbitraryFueled` instance automatically results in an `Arbitrary` instance, |
| 26 | +as detailed in `Arbitrary.lean`.) |
| 27 | +
|
| 28 | +Note that the resulting `Arbitrary` and `ArbitraryFueled` instance should be considered |
| 29 | +to be opaque, following the convention for the deriving handler for Mathlib's `Encodable` typeclass. |
| 30 | +
|
| 31 | +Example usage: |
| 32 | +
|
| 33 | +```lean |
| 34 | +-- Datatype for binary trees |
| 35 | +inductive Tree |
| 36 | + | Leaf : Tree |
| 37 | + | Node : Nat → Tree → Tree → Tree |
| 38 | + deriving Arbitrary |
| 39 | +``` |
| 40 | +
|
| 41 | +To sample from a derived generator, users can simply call `Arbitrary.runArbitrary`, specify the type |
| 42 | +for the desired generated values and provide some Nat to act as the generator's fuel parameter (10 in the example below): |
| 43 | +
|
| 44 | +```lean |
| 45 | +#eval Arbitrary.runArbitrary (α := Tree) 10 |
| 46 | +``` |
| 47 | +
|
| 48 | +To view the code for the derived generator, users can enable trace messages using the `plausible.deriving.arbitrary` trace class as follows: |
| 49 | +
|
| 50 | +```lean |
| 51 | +set_option trace.plausible.deriving.arbitrary true |
| 52 | +``` |
| 53 | +
|
| 54 | +## Main definitions |
| 55 | +* Deriving handler for `ArbitraryFueled` typeclass |
| 56 | +
|
| 57 | +-/ |
| 58 | + |
| 59 | +namespace Plausible |
| 60 | + |
| 61 | +open Arbitrary |
| 62 | + |
| 63 | +/-- Takes the name of a constructor for an algebraic data type and returns an array |
| 64 | + containing `(argument_name, argument_type)` pairs. |
| 65 | +
|
| 66 | + If the algebraic data type is defined using anonymous constructor argument syntax, i.e. |
| 67 | + ``` |
| 68 | + inductive T where |
| 69 | + C1 : τ1 → … → τn |
| 70 | + … |
| 71 | + ``` |
| 72 | + Lean produces macro scopes when we try to access the names for the constructor args. |
| 73 | + In this case, we remove the macro scopes so that the name is user-accessible. |
| 74 | + (This will result in constructor argument names being non-unique in the array |
| 75 | + that is returned -- it is the caller's responsibility to produce fresh names.) |
| 76 | +-/ |
| 77 | +def getCtorArgsNamesAndTypes (header : Header) (indVal : InductiveVal) (ctorName : Name) : MetaM (Array (Name × Expr)) := do |
| 78 | + let ctorInfo ← getConstInfoCtor ctorName |
| 79 | + |
| 80 | + forallTelescopeReducing ctorInfo.type fun args _ => do |
| 81 | + let mut argNamesAndTypes := #[] |
| 82 | + |
| 83 | + for i in *...args.size do |
| 84 | + let arg := args[i]! |
| 85 | + let localDecl ← arg.fvarId!.getDecl |
| 86 | + let argType := localDecl.type |
| 87 | + |
| 88 | + let argName ← if i < indVal.numParams then pure header.argNames[i]! else Core.mkFreshUserName `a |
| 89 | + if i < indVal.numParams then |
| 90 | + continue |
| 91 | + else |
| 92 | + argNamesAndTypes := argNamesAndTypes.push (argName, argType) |
| 93 | + |
| 94 | + return argNamesAndTypes |
| 95 | + |
| 96 | +-- Note: the following functions closely follow the implementation of the deriving handler for `Repr` / `BEq` |
| 97 | +-- (see https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Deriving/Repr.lean). |
| 98 | + |
| 99 | +open TSyntax.Compat in |
| 100 | +/-- Variant of `Deriving.Util.mkHeader` where we don't add an explicit binder |
| 101 | + of the form `($targetName : $targetType)` to the field `binders` |
| 102 | + (i.e. `binders` contains only implicit binders) -/ |
| 103 | +def mkHeaderWithOnlyImplicitBinders (className : Name) (arity : Nat) (indVal : InductiveVal) : TermElabM Header := do |
| 104 | + let argNames ← mkInductArgNames indVal |
| 105 | + let binders ← mkImplicitBinders argNames |
| 106 | + let targetType ← mkInductiveApp indVal argNames |
| 107 | + let mut targetNames := #[] |
| 108 | + for _ in [:arity] do |
| 109 | + targetNames := targetNames.push (← mkFreshUserName `x) |
| 110 | + let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) |
| 111 | + return { |
| 112 | + binders := binders |
| 113 | + argNames := argNames |
| 114 | + targetNames := targetNames |
| 115 | + targetType := targetType |
| 116 | + } |
| 117 | + |
| 118 | +open TSyntax.Compat in |
| 119 | +/-- Variant of `Deriving.Util.mkInstanceCmds` which is specialized to creating `ArbitraryFueled` instances |
| 120 | + that have `Arbitrary` inst-implicit binders. |
| 121 | +
|
| 122 | + Note that we can't use `mkInstanceCmds` out of the box, |
| 123 | + since it expects the inst-implicit binders and the instance we're creating to both belong to the same typeclass. -/ |
| 124 | +def mkArbitraryFueledInstanceCmds (ctx : Deriving.Context) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Command) := do |
| 125 | + let mut instances := #[] |
| 126 | + for i in [:ctx.typeInfos.size] do |
| 127 | + let indVal := ctx.typeInfos[i]! |
| 128 | + if typeNames.contains indVal.name then |
| 129 | + let auxFunName := ctx.auxFunNames[i]! |
| 130 | + let argNames ← mkInductArgNames indVal |
| 131 | + let binders ← mkImplicitBinders argNames |
| 132 | + let binders := binders ++ (← mkInstImplicitBinders ``Arbitrary indVal argNames) -- this line is changed from |
| 133 | + let indType ← mkInductiveApp indVal argNames |
| 134 | + let type ← `($(mkCIdent ``ArbitraryFueled) $indType) |
| 135 | + let mut val := mkIdent auxFunName |
| 136 | + if useAnonCtor then |
| 137 | + val ← `(⟨$val⟩) |
| 138 | + let instCmd ← `(instance $binders:implicitBinder* : $type := $val) |
| 139 | + instances := instances.push instCmd |
| 140 | + return instances |
| 141 | + |
| 142 | +/-- Creates a `Header` for the `Arbitrary` typeclass -/ |
| 143 | +def mkArbitraryHeader (indVal : InductiveVal) : TermElabM Header := |
| 144 | + mkHeaderWithOnlyImplicitBinders ``Arbitrary 1 indVal |
| 145 | + |
| 146 | +/-- Creates the *body* of the generator that appears in the instance of the `ArbitraryFueled` typeclass -/ |
| 147 | +def mkBody (header : Header) (inductiveVal : InductiveVal) (generatorType : TSyntax `term) : TermElabM Term := do |
| 148 | + -- Fetch the name of the target type (the type for which we are deriving a generator) |
| 149 | + let targetTypeName := inductiveVal.name |
| 150 | + |
| 151 | + -- Produce `Ident`s for the `fuel` argument for the lambda |
| 152 | + -- at the end of the generator function, as well as the `aux_arb` inner helper function |
| 153 | + let freshFuel := Lean.mkIdent (← Core.mkFreshUserName `fuel) |
| 154 | + let freshFuel' := Lean.mkIdent (← Core.mkFreshUserName `fuel') |
| 155 | + let auxArb := mkIdent `aux_arb |
| 156 | + |
| 157 | + -- Maintain two arrays which will be populated with pairs |
| 158 | + -- where the first component is a sub-generator (non-recursive / recursive) |
| 159 | + -- and the 2nd component is the generator's associated weight |
| 160 | + let mut weightedNonRecursiveGenerators := #[] |
| 161 | + let mut weightedRecursiveGenerators := #[] |
| 162 | + |
| 163 | + -- We also need to keep track of non-recursive generators without their weights, |
| 164 | + -- since some of Plausible's `Gen` combinators operate on generator functions |
| 165 | + let mut nonRecursiveGeneratorsNoWeights := #[] |
| 166 | + |
| 167 | + for ctorName in inductiveVal.ctors do |
| 168 | + let ctorIdent := mkIdent ctorName |
| 169 | + |
| 170 | + let ctorArgNamesTypes ← getCtorArgsNamesAndTypes header inductiveVal ctorName |
| 171 | + let (ctorArgNames, ctorArgTypes) := Array.unzip ctorArgNamesTypes |
| 172 | + |
| 173 | + /- Produce fresh names for each of the constructor's arguments. |
| 174 | + Producing fresh names is necessary in order to handle |
| 175 | + constructors expressed using the following syntax: |
| 176 | + ``` |
| 177 | + inductive Foo |
| 178 | + | C : T1 → ... → Tn |
| 179 | + ``` |
| 180 | + in which all the arguments to the constructor `C` don't have explicit names. |
| 181 | + -/ |
| 182 | + let ctorArgIdents := Lean.mkIdent <$> ctorArgNames |
| 183 | + let ctorArgIdentsTypes := Array.zip ctorArgIdents ctorArgTypes |
| 184 | + |
| 185 | + if ctorArgNamesTypes.isEmpty then |
| 186 | + -- Constructor is nullary, we can just use an generator of the form `pure ...` with weight 1, |
| 187 | + -- following the QuickChick convention. |
| 188 | + -- (For clarity, this generator is parenthesized in the code produced.) |
| 189 | + let pureGen ← `(($(Lean.mkIdent `pure) $ctorIdent)) |
| 190 | + weightedNonRecursiveGenerators := weightedNonRecursiveGenerators.push (← `((1, $pureGen))) |
| 191 | + nonRecursiveGeneratorsNoWeights := nonRecursiveGeneratorsNoWeights.push pureGen |
| 192 | + else |
| 193 | + -- Add all the constructor's argument names + types to the local context, |
| 194 | + -- then produce the body of the sub-generator (& a flag indicating if the constructor is recursive) |
| 195 | + let (generatorBody, ctorIsRecursive) ← |
| 196 | + withLocalDeclsDND ctorArgNamesTypes (fun _ => do |
| 197 | + let mut doElems := #[] |
| 198 | + |
| 199 | + -- Flag to indicate whether the constructor is recursive (initialized to `false`) |
| 200 | + let mut ctorIsRecursive := false |
| 201 | + |
| 202 | + -- Examine each argument to see which of them require recursive calls to the generator |
| 203 | + for (freshIdent, argType) in ctorArgIdentsTypes do |
| 204 | + -- If the argument's type is the same as the target type, |
| 205 | + -- produce a recursive call to the generator using `aux_arb`, |
| 206 | + -- otherwise generate a value using `arbitrary` |
| 207 | + let bindExpr ← |
| 208 | + if argType.getAppFn.constName == targetTypeName then |
| 209 | + -- We've detected that the constructor has a recursive argument, so we update the flag |
| 210 | + ctorIsRecursive := true |
| 211 | + `(doElem| let $freshIdent ← $(mkIdent `aux_arb):term $(freshFuel'):term) |
| 212 | + else |
| 213 | + `(doElem| let $freshIdent ← $(mkIdent ``Arbitrary.arbitrary):term) |
| 214 | + doElems := doElems.push bindExpr |
| 215 | + |
| 216 | + -- Create an expression `return C x1 ... xn` at the end of the generator, where |
| 217 | + -- `C` is the constructor name and the `xi` are the generated values for the args |
| 218 | + let pureExpr ← `(doElem| return $ctorIdent $ctorArgIdents*) |
| 219 | + doElems := doElems.push pureExpr |
| 220 | + |
| 221 | + -- Put the body of the generator together in an explicitly-parenthesized `do`-block |
| 222 | + let generatorBody ← `((do $[$doElems:doElem]*)) |
| 223 | + pure (generatorBody, ctorIsRecursive)) |
| 224 | + |
| 225 | + if !ctorIsRecursive then |
| 226 | + -- Non-recursive generators have weight 1, following the QuickChick convention |
| 227 | + weightedNonRecursiveGenerators := weightedNonRecursiveGenerators.push (← `((1, $generatorBody))) |
| 228 | + nonRecursiveGeneratorsNoWeights := nonRecursiveGeneratorsNoWeights.push generatorBody |
| 229 | + else |
| 230 | + -- Recursive generaotrs have an associated weight of `fuel' + 1`, following the QuickChick convention |
| 231 | + weightedRecursiveGenerators := weightedRecursiveGenerators.push (← ``(($freshFuel' + 1, $generatorBody))) |
| 232 | + |
| 233 | + -- Use the first non-recursive generator (without its weight) as the default generator |
| 234 | + -- If the target type has no non-recursive constructors, we emit an error message |
| 235 | + -- saying that we cannot derive a generator for that type |
| 236 | + let defaultGenerator ← Option.getDM (nonRecursiveGeneratorsNoWeights[0]?) |
| 237 | + (throwError m!"derive Arbitrary failed, {targetTypeName} has no non-recursive constructors") |
| 238 | + |
| 239 | + -- Create the cases for the pattern-match on the fuel argument |
| 240 | + -- If `fuel = 0`, pick one of the non-recursive generators |
| 241 | + let mut caseExprs := #[] |
| 242 | + let zeroCase ← `(Term.matchAltExpr| | $(mkIdent ``Nat.zero) => $(mkIdent ``Gen.oneOfWithDefault) $defaultGenerator [$nonRecursiveGeneratorsNoWeights,*]) |
| 243 | + caseExprs := caseExprs.push zeroCase |
| 244 | + |
| 245 | + -- If `fuel = fuel' + 1`, pick a generator (it can be non-recursive or recursive) |
| 246 | + let mut allWeightedGenerators ← `([$weightedNonRecursiveGenerators,*, $weightedRecursiveGenerators,*]) |
| 247 | + let succCase ← `(Term.matchAltExpr| | $freshFuel' + 1 => $(mkIdent ``Gen.frequency) $defaultGenerator $allWeightedGenerators) |
| 248 | + caseExprs := caseExprs.push succCase |
| 249 | + |
| 250 | + -- Create function argument for the generator fuel |
| 251 | + let fuelParam ← `(Term.letIdBinder| ($freshFuel : $(mkIdent `Nat))) |
| 252 | + let matchExpr ← `(match $freshFuel:ident with $caseExprs:matchAlt*) |
| 253 | + |
| 254 | + -- Create an instance of the `ArbitraryFueled` typeclass |
| 255 | + `(let rec $auxArb:ident $fuelParam : $generatorType := |
| 256 | + $matchExpr |
| 257 | + fun $freshFuel => $auxArb $freshFuel) |
| 258 | + |
| 259 | + |
| 260 | +/-- Creates the function definition for the derived generator -/ |
| 261 | +def mkAuxFunction (ctx : Deriving.Context) (i : Nat) : TermElabM Command := do |
| 262 | + let auxFunName := ctx.auxFunNames[i]! |
| 263 | + let indVal := ctx.typeInfos[i]! |
| 264 | + let header ← mkArbitraryHeader indVal |
| 265 | + let mut binders := header.binders |
| 266 | + |
| 267 | + -- Determine the type of the generator |
| 268 | + -- (the `Plausible.Gen` type constructor applied to the name of the `inductive` type, plus any type parameters) |
| 269 | + let targetType ← mkInductiveApp ctx.typeInfos[i]! header.argNames |
| 270 | + let generatorType ← `($(mkIdent ``Plausible.Gen) $targetType) |
| 271 | + |
| 272 | + -- Create the body of the generator function |
| 273 | + let mut body ← mkBody header indVal generatorType |
| 274 | + |
| 275 | + -- For mutually-recursive types, we need to create |
| 276 | + -- local `let`-definitions containing the relevant `ArbitraryFueled` instances so that |
| 277 | + -- the derived generator typechecks |
| 278 | + if ctx.usePartial then |
| 279 | + let letDecls ← mkLocalInstanceLetDecls ctx ``ArbitraryFueled header.argNames |
| 280 | + body ← mkLet letDecls body |
| 281 | + |
| 282 | + -- If we are deriving a generator for a bunch of mutually-recursive types, |
| 283 | + -- the derived generator needs to be marked `partial` (following the implementation |
| 284 | + -- of the `deriving Repr` handler) |
| 285 | + if ctx.usePartial then |
| 286 | + `(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $(mkIdent ``Nat) → $generatorType := $body:term) |
| 287 | + else |
| 288 | + `(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $(mkIdent ``Nat) → $generatorType := $body:term) |
| 289 | + |
| 290 | +/-- Creates a `mutual ... end` block containing the definitions of the derived generators -/ |
| 291 | +def mkMutualBlock (ctx : Deriving.Context) : TermElabM Syntax := do |
| 292 | + let mut auxDefs := #[] |
| 293 | + for i in *...ctx.typeInfos.size do |
| 294 | + auxDefs := auxDefs.push (← mkAuxFunction ctx i) |
| 295 | + `(mutual |
| 296 | + $auxDefs:command* |
| 297 | + end) |
| 298 | + |
| 299 | +/-- Creates an instance of the `ArbitraryFueled` typeclass -/ |
| 300 | +private def mkArbitraryFueledInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do |
| 301 | + let ctx ← mkContext ``Arbitrary "arbitrary" declName |
| 302 | + let cmds := #[← mkMutualBlock ctx] ++ (← mkArbitraryFueledInstanceCmds ctx #[declName]) |
| 303 | + trace[plausible.deriving.arbitrary] "\n{cmds}" |
| 304 | + return cmds |
| 305 | + |
| 306 | +/-- Deriving handler which produces an instance of the `ArbitraryFueled` typeclass for |
| 307 | + each type specified in `declNames` -/ |
| 308 | +def mkArbitraryInstanceHandler (declNames : Array Name) : CommandElabM Bool := do |
| 309 | + if (← declNames.allM isInductive) then |
| 310 | + for declName in declNames do |
| 311 | + let cmds ← liftTermElabM $ mkArbitraryFueledInstanceCmd declName |
| 312 | + cmds.forM elabCommand |
| 313 | + return true |
| 314 | + else |
| 315 | + throwError "Cannot derive instance of Arbitrary typeclass for non-inductive types" |
| 316 | + return false |
| 317 | + |
| 318 | +initialize |
| 319 | + registerDerivingHandler ``Arbitrary mkArbitraryInstanceHandler |
| 320 | + |
| 321 | +end Plausible |
0 commit comments