|
| 1 | +/- |
| 2 | +Copyright (c) 2024 Plausible Contributors. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Alok Singh |
| 5 | +-/ |
| 6 | +module |
| 7 | + |
| 8 | +public section |
| 9 | + |
| 10 | +/-! |
| 11 | +# Conjecture Engine |
| 12 | +
|
| 13 | +Hypothesis-style testing: generate bytes, interpret via `Strategy`, shrink bytes. |
| 14 | +No per-type `Shrinkable` needed. See https://hypothesis.works/articles/how-hypothesis-works/ |
| 15 | +-/ |
| 16 | + |
| 17 | +namespace Plausible.Conjecture |
| 18 | + |
/-- Renders a `ByteArray` as `ByteArray.mk #[..]`, showing every byte as two
lowercase hexadecimal digits. -/
instance : Repr ByteArray where
  reprPrec ba prec :=
    -- Maps a nibble (0–15) to '0'..'9' (ASCII 48+) or 'a'..'f' (ASCII 97+).
    let toHexDigit (n : Nat) : Char :=
      if n < 10 then Char.ofNat (48 + n) else Char.ofNat (87 + n)
    let hex := ba.toList.map fun b =>
      s!"{toHexDigit (b.toNat / 16)}{toHexDigit (b.toNat % 16)}"
    -- Forward the caller's precedence: the rendering is an application, so it
    -- must be parenthesized when it appears as an argument of another term.
    Repr.addAppParen s!"ByteArray.mk #{hex}" prec
| 27 | + |
/--
Half-open interval `[start, stop)` of positions in the byte buffer.
Shrinking prefers complete spans.
-/
structure Span where
  /-- Position of the first byte inside the interval. -/
  start : Nat
  /-- Position one past the last byte inside the interval. -/
  stop : Nat
  /-- Optional tag for the interval; empty unless supplied. -/
  label : String := ""
deriving Repr, BEq, Inhabited
| 34 | + |
/--
Outcome classification with four cases.
NOTE(review): nothing in this section constructs or inspects these values; the
names mirror Hypothesis' test statuses — confirm the intended meaning of each.
-/
inductive Status where
  | valid
  | interesting
  | invalid
  | overrun
deriving Repr, BEq, Inhabited
| 38 | + |
/-- Byte buffer together with the span annotations used for shrinking. -/
structure ChoiceSequence where
  /-- The raw bytes that strategies consume. -/
  buffer : ByteArray
  /-- Spans that have been closed so far. -/
  spans : Array Span
  /-- Read cursor: position of the next byte to be drawn. -/
  index : Nat := 0
  /-- Start positions of spans that are opened but not yet closed, innermost first. -/
  spanStack : List Nat := []
  /-- Size limit carried with the sequence (defaults to 8 KiB).
  NOTE(review): no code in this section reads it — confirm where it is enforced. -/
  maxSize : Nat := 8 * 1024
deriving Repr, Inhabited
| 47 | + |
| 48 | +namespace ChoiceSequence |
| 49 | + |
/-- A choice sequence with no bytes and no spans, carrying the given `maxSize`. -/
def empty (maxSize : Nat := 8 * 1024) : ChoiceSequence where
  buffer := ByteArray.empty
  spans := #[]
  maxSize := maxSize
| 52 | + |
/-- Wraps existing `bytes` as a fresh choice sequence: no spans, cursor at 0,
and the given `maxSize`. -/
def ofBytes (bytes : ByteArray) (maxSize : Nat := 8 * 1024) : ChoiceSequence :=
  { ChoiceSequence.empty maxSize with buffer := bytes }
| 55 | + |
/-- Total number of bytes in the buffer. -/
def size (cs : ChoiceSequence) : Nat :=
  cs.buffer.size

/-- Number of bytes between the read cursor and the end of the buffer
(0 when the cursor is at or past the end, since `Nat` subtraction truncates). -/
def remaining (cs : ChoiceSequence) : Nat :=
  cs.buffer.size - cs.index

/-- `true` once the read cursor has reached or passed the end of the buffer. -/
def exhausted (cs : ChoiceSequence) : Bool :=
  decide (cs.buffer.size ≤ cs.index)
| 59 | + |
/--
Strict shortlex order on the buffers of two choice sequences: a shorter buffer
is always smaller; buffers of equal length are compared byte by byte from the
front. Returns `false` when the buffers are identical.
-/
def lexLt (cs1 cs2 : ChoiceSequence) : Bool :=
  let lhs := cs1.buffer
  let rhs := cs2.buffer
  if lhs.size != rhs.size then
    -- Different lengths: the length alone decides.
    decide (lhs.size < rhs.size)
  else
    -- Same length: the first position where the bytes differ decides.
    match (List.range lhs.size).find? fun i => lhs.get! i != rhs.get! i with
    | some i => decide (lhs.get! i < rhs.get! i)
    | none => false
| 74 | + |
| 75 | +end ChoiceSequence |
| 76 | + |
/-- Result of running a draw against a choice sequence. -/
inductive DrawResult (α : Type) where
  /-- The draw produced `value`; `cs` is the choice sequence after the draw. -/
  | ok (value : α) (cs : ChoiceSequence)
  /-- The draw could not complete (e.g. it asked for more bytes than remain). -/
  | overrun
deriving Inhabited
| 81 | + |
/-- A draw: a function from the current choice sequence to a `DrawResult`. -/
abbrev StrategyM (α : Type) : Type :=
  ChoiceSequence → DrawResult α
| 83 | + |
/-- Sequencing for draws: the choice sequence is threaded through each step and
the first `overrun` aborts the rest of the computation. -/
instance : Monad StrategyM where
  pure value := fun state => DrawResult.ok value state
  bind action next := fun state =>
    match action state with
    | .overrun => .overrun
    | .ok value state' => next value state'
| 90 | + |
/-- `throw` produces `overrun`; `tryCatch` runs the handler on the *original*
choice sequence when the protected draw overruns. -/
instance : MonadExcept Unit StrategyM where
  throw _ := fun _ => DrawResult.overrun
  tryCatch action recover := fun state =>
    match action state with
    | .ok value state' => .ok value state'
    | .overrun => recover () state
| 97 | + |
| 98 | +namespace StrategyM |
| 99 | + |
/-- Consumes the next `n` bytes from the buffer and advances the cursor past
them. Overruns when fewer than `n` bytes remain. -/
def drawBytes (n : Nat) : StrategyM ByteArray := fun cs =>
  let stop := cs.index + n
  if stop ≤ cs.buffer.size then
    .ok (cs.buffer.extract cs.index stop) { cs with index := stop }
  else
    .overrun
| 103 | + |
/-- Consumes a single byte; overruns when the buffer is exhausted. -/
def drawByte : StrategyM UInt8 := fun cs =>
  match drawBytes 1 cs with
  | .ok chunk cs' => .ok (chunk.get! 0) cs'
  | .overrun => .overrun
| 107 | + |
/-- Opens a span at the current cursor position by pushing it on the span stack. -/
def startSpan : StrategyM Unit := fun cs =>
  let opened := cs.index :: cs.spanStack
  .ok () { cs with spanStack := opened }
| 110 | + |
/-- Closes the innermost open span at the current cursor position and records
it with `label`. Does nothing when no span is open. -/
def endSpan (label : String := "") : StrategyM Unit := fun cs =>
  if let start :: rest := cs.spanStack then
    let closed : Span := { start, stop := cs.index, label }
    .ok () { cs with spans := cs.spans.push closed, spanStack := rest }
  else
    .ok () cs
| 116 | + |
/-- Runs `m` inside a span: opens a span, runs `m`, closes the span under
`label`, and returns `m`'s value. An overrun at any step is propagated. -/
def withSpan (label : String := "") (m : StrategyM α) : StrategyM α := fun cs =>
  match startSpan cs with
  | .overrun => .overrun
  | .ok _ opened =>
    match m opened with
    | .overrun => .overrun
    | .ok value drawn =>
      match endSpan label drawn with
      | .overrun => .overrun
      | .ok _ closed => .ok value closed
| 122 | + |
/-- A size hint derived from the remaining bytes: one tenth of them, but never
less than 1. Leaves the choice sequence untouched. -/
def getSize : StrategyM Nat := fun cs =>
  let hint := Nat.max (cs.remaining / 10) 1
  .ok hint cs
| 125 | + |
| 126 | +end StrategyM |
| 127 | + |
/-- Types whose values can be produced by interpreting bytes of a choice sequence. -/
class Strategy (α : Type) where
  /-- The draw that interprets upcoming bytes as a value of `α`. -/
  draw : StrategyM α
| 130 | + |
| 131 | +namespace Strategy |
| 132 | + |
/-- Draws a natural number from 8 bytes, read big-endian (first byte is the
most significant). -/
instance instStrategyNat : Strategy Nat where
  draw := StrategyM.withSpan "Nat" do
    let bytes ← StrategyM.drawBytes 8
    return bytes.toList.foldl (fun (acc : Nat) b => acc * 256 + b.toNat) 0
| 137 | + |
/-- Draws a natural number and reduces it modulo `n`.
For `n = 0` the raw draw is returned unchanged, because `x % 0 = x` on `Nat`. -/
def natBelow (n : Nat) : StrategyM Nat := StrategyM.withSpan "Nat<" do
  let raw ← instStrategyNat.draw
  return raw % n
| 140 | + |
/-- Draws a natural number in the inclusive range `[lo, hi]`.
When `hi < lo` the width `hi - lo + 1` truncates to 1, so the result is `lo`. -/
def natRange (lo hi : Nat) : StrategyM Nat := StrategyM.withSpan "Nat[]" do
  let raw ← instStrategyNat.draw
  let width := hi - lo + 1
  return lo + raw % width
| 143 | + |
/-- Draws one byte; the result is `true` exactly when that byte is odd. -/
instance : Strategy Bool where
  draw := StrategyM.withSpan "Bool" do
    let byte ← StrategyM.drawByte
    return byte % 2 != 0
| 147 | + |
/-- Draws a magnitude (`Nat`) followed by a sign (`Bool`) and combines them. -/
instance : Strategy Int where
  draw := StrategyM.withSpan "Int" do
    let magnitude ← instStrategyNat.draw
    let negative ← Strategy.draw (α := Bool)
    match negative with
    | true => return -(magnitude : Int)
    | false => return (magnitude : Int)
| 153 | + |
/-- Draws a single byte as-is. -/
instance : Strategy UInt8 where
  draw := StrategyM.drawByte |> StrategyM.withSpan "UInt8"
| 156 | + |
/-- Draws a `UInt64` from 8 bytes, read big-endian. -/
instance : Strategy UInt64 where
  draw := StrategyM.withSpan "UInt64" do
    let bytes ← StrategyM.drawBytes 8
    return bytes.toList.foldl (fun (acc : UInt64) b => acc * 256 + b.toUInt64) 0
| 161 | + |
/-- Draws a printable ASCII character (code points 32 through 126) from one byte. -/
instance : Strategy Char where
  draw := StrategyM.withSpan "Char" do
    let byte ← StrategyM.drawByte
    -- 95 printable characters, starting at the space character (code 32).
    let code := 32 + byte.toNat % 95
    return Char.ofNat code
| 165 | + |
/-- Draws a list: before each element a `Bool` is drawn, and the list ends as
soon as that `Bool` is `false`. Every recursive tail is wrapped in its own
"List" span. -/
partial def list (elem : StrategyM α) : StrategyM (List α) := StrategyM.withSpan "List" do
  let more ← Strategy.draw (α := Bool)
  if more then
    let head ← elem
    let tail ← list elem
    return head :: tail
  else
    return []
| 169 | + |
/-- Lists are drawn with `list`, using the element type's own strategy. -/
instance instStrategyList [Strategy α] : Strategy (List α) :=
  ⟨list Strategy.draw⟩
| 172 | + |
/-- Arrays are drawn as lists and then converted. -/
instance [Strategy α] : Strategy (Array α) where
  draw := StrategyM.withSpan "Array" do
    let elems : List α ← instStrategyList.draw
    return elems.toArray
| 175 | + |
/-- Strings are drawn as a list of characters and then assembled. -/
instance : Strategy String where
  draw := StrategyM.withSpan "String" do
    let chars ← (instStrategyList (α := Char)).draw
    return String.ofList chars
| 178 | + |
/-- Draws a `Bool` first: `true` means draw a value and wrap it in `some`,
`false` means `none`. -/
instance [Strategy α] : Strategy (Option α) where
  draw := StrategyM.withSpan "Option" do
    let present ← Strategy.draw (α := Bool)
    if present then
      return some (← Strategy.draw)
    else
      return none
| 183 | + |
/-- Draws the first component, then the second. -/
instance [Strategy α] [Strategy β] : Strategy (α × β) where
  draw := StrategyM.withSpan "Prod" do
    let first : α ← Strategy.draw
    let second : β ← Strategy.draw
    return (first, second)
| 186 | + |
/-- Draws a `Bool` first: `true` selects the left alternative, `false` the right. -/
instance [Strategy α] [Strategy β] : Strategy (Sum α β) where
  draw := StrategyM.withSpan "Sum" do
    let pickLeft ← Strategy.draw (α := Bool)
    if pickLeft then
      return Sum.inl (← Strategy.draw)
    else
      return Sum.inr (← Strategy.draw)
| 191 | + |
/-- Draws a natural number and reduces it modulo `n + 1`, which is always
positive, so the bound proof never needs a side condition. -/
instance {n : Nat} : Strategy (Fin n.succ) where
  draw := StrategyM.withSpan "Fin" do
    let raw ← instStrategyNat.draw
    return Fin.mk (raw % n.succ) (Nat.mod_lt raw (Nat.succ_pos n))
| 196 | + |
| 197 | +end Strategy |
| 198 | + |
| 199 | +namespace Shrinker |
| 200 | + |
/-- One candidate per recorded span that ends inside the buffer: the buffer
with that span's bytes removed. Candidates start with no spans and the cursor
at 0. -/
def deleteSpans (cs : ChoiceSequence) : List ChoiceSequence :=
  let total := cs.buffer.size
  let usable := cs.spans.toList.filter fun span => decide (span.stop ≤ total)
  usable.map fun span =>
    let before := cs.buffer.extract 0 span.start
    let after := cs.buffer.extract span.stop total
    { cs with buffer := before ++ after, spans := #[], index := 0 }
| 210 | + |
/-- One candidate per non-zero byte, in buffer order: the buffer with that byte
halved. Candidates start with no spans and the cursor at 0. -/
def reduceBytes (cs : ChoiceSequence) : List ChoiceSequence :=
  let positions := (List.range cs.buffer.size).filter fun i =>
    decide (cs.buffer.get! i > 0)
  positions.map fun i =>
    let halved := cs.buffer.get! i / 2
    { cs with buffer := cs.buffer.set! i halved, spans := #[], index := 0 }
| 216 | + |
/-- One candidate per recorded span that ends inside the buffer: the buffer
with every byte of that span set to 0. Candidates start with no spans and the
cursor at 0. -/
def zeroSpans (cs : ChoiceSequence) : List ChoiceSequence :=
  let usable := cs.spans.toList.filter fun span => decide (span.stop ≤ cs.buffer.size)
  usable.map fun span =>
    -- Offsets 0 .. (stop - start - 1) cover exactly the positions start .. stop - 1.
    let zeroed := (List.range (span.stop - span.start)).foldl
      (fun buf offset => buf.set! (span.start + offset) 0) cs.buffer
    { cs with buffer := zeroed, spans := #[], index := 0 }
| 227 | + |
/-- Like `reduceBytes`, but ordered so that the positions holding the largest
bytes are tried first. Each candidate halves one non-zero byte and starts with
no spans and the cursor at 0. -/
def sortedByteReductions (cs : ChoiceSequence) : List ChoiceSequence :=
  let positions := List.range cs.buffer.size |>.map fun i => (i, cs.buffer.get! i)
  let nonZero := positions.filter (·.2 > 0)
  let descending := nonZero.toArray.qsort (fun a b => a.2 > b.2)
  descending.toList.map fun (i, byte) =>
    { cs with buffer := cs.buffer.set! i (byte / 2), spans := #[], index := 0 }
| 233 | + |
/-- All shrink candidates for `cs`, in the order they should be tried:
span deletions, then span zeroings, then single-byte reductions. -/
def shrink (cs : ChoiceSequence) : List ChoiceSequence :=
  let deletions := deleteSpans cs
  let zeroings := zeroSpans cs
  let reductions := sortedByteReductions cs
  deletions ++ (zeroings ++ reductions)
| 236 | + |
/-- Keeps only the candidates that are strictly smaller than `original`
according to `ChoiceSequence.lexLt`. -/
def filterSmaller (original : ChoiceSequence) (candidates : List ChoiceSequence) : List ChoiceSequence :=
  candidates.filter fun candidate => candidate.lexLt original
| 239 | + |
| 240 | +end Shrinker |
| 241 | + |
/-- Builds a choice sequence of `size` bytes, each obtained from
`IO.rand 0 255`. The result has no spans and `maxSize` set to `size * 2`. -/
def generateRandom (size : Nat) : IO ChoiceSequence := do
  let bytes ← (List.range size).foldlM (init := ByteArray.empty) fun acc _ => do
    let r ← IO.rand 0 255
    pure (acc.push r.toUInt8)
  return { buffer := bytes, spans := #[], maxSize := size * 2 }
| 247 | + |
/-- Default directory of the example database: `.plausible/examples`,
relative to the working directory. -/
def defaultDbPath : System.FilePath :=
  System.FilePath.mk ".plausible" / "examples"
| 249 | + |
/-- A record associating a test name with a stored byte sequence.
NOTE(review): nothing in this section constructs or reads `DbEntry`. -/
structure DbEntry where
  /-- Name of the test the entry belongs to. -/
  testName : String
  /-- The stored bytes. -/
  choiceSeq : ByteArray
  /-- Time of the entry. NOTE(review): unit and epoch are not established here — confirm. -/
  timestamp : Nat
deriving Repr
| 255 | + |
| 256 | +namespace ExampleDb |
| 257 | + |
/-- Creates the database directory `path`, including missing parents. -/
def ensureDir (path : System.FilePath := defaultDbPath) : IO Unit := do
  IO.FS.createDirAll path
| 260 | + |
/-- Writes the buffer of `cs` to `<path>/<testName>.bin`, creating the
directory first. Spans and cursor are not persisted.
NOTE(review): `testName` is used verbatim as a file name — confirm callers
never pass names containing path separators. -/
def save (testName : String) (cs : ChoiceSequence) (path : System.FilePath := defaultDbPath) : IO Unit := do
  -- Make sure the target directory exists before writing.
  IO.FS.createDirAll path
  let target := path / s!"{testName}.bin"
  IO.FS.writeBinFile target cs.buffer
| 264 | + |
/-- Reads `<path>/<testName>.bin` and wraps its bytes with
`ChoiceSequence.ofBytes`. Returns `none` when the file does not exist. -/
def load (testName : String) (path : System.FilePath := defaultDbPath) : IO (Option ChoiceSequence) := do
  let file := path / s!"{testName}.bin"
  let present ← file.pathExists
  if present then
    let bytes ← IO.FS.readBinFile file
    return some (ChoiceSequence.ofBytes bytes)
  else
    return none
| 270 | + |
/-- Names of all stored tests: every directory entry ending in `.bin`, with
that 4-character suffix removed. Returns `[]` when `path` does not exist. -/
def listTests (path : System.FilePath := defaultDbPath) : IO (List String) := do
  let present ← path.pathExists
  if !present then
    return []
  let entries ← path.readDir
  let binFiles := entries.toList.filter fun entry => entry.fileName.endsWith ".bin"
  return binFiles.map fun entry => (entry.fileName.dropEnd 4).toString
| 277 | + |
| 278 | +end ExampleDb |
| 279 | + |
/-- Warnings produced by the health checks; each carries the measured value
that triggered it. -/
inductive HealthWarning where
  /-- Average time per example, in milliseconds. -/
  | tooSlow (avgMs : Float)
  /-- Fraction of examples that were invalid. -/
  | filterTooMuch (ratio : Float)
  /-- Average number of bytes per example. -/
  | dataTooLarge (avgBytes : Float)
deriving Repr
| 285 | + |
| 286 | +namespace HealthCheck |
| 287 | + |
/-- Warns with `tooSlow` when the average time per example exceeds 200 ms.

* `totalMs` — total elapsed time in milliseconds.
* `numExamples` — number of examples that time was spent on.

Returns `none` when `numExamples` is 0: there is no average to judge, and
dividing by zero would otherwise turn any positive `totalMs` into an infinite
average and a spurious warning. -/
def checkSpeed (totalMs : Float) (numExamples : Nat) : Option HealthWarning :=
  if numExamples == 0 then none
  else
    let avgMs := totalMs / numExamples.toFloat
    if avgMs > 200.0 then some (.tooSlow avgMs) else none
| 291 | + |
/-- Warns with `filterTooMuch` when more than half of all examples were
invalid. Returns `none` when there were no examples at all. -/
def checkFilterRatio (validCount invalidCount : Nat) : Option HealthWarning :=
  let total := validCount + invalidCount
  if total == 0 then none
  else
    let ratio := invalidCount.toFloat / total.toFloat
    if ratio > 0.5 then some (.filterTooMuch ratio) else none
| 298 | + |
/-- Warns with `dataTooLarge` when the average number of bytes per example
exceeds 4096. With zero examples the total is divided by 1 instead. -/
def checkDataSize (totalBytes : Nat) (numExamples : Nat) : Option HealthWarning :=
  let divisor : Float := if numExamples > 0 then numExamples.toFloat else 1.0
  let avgBytes := totalBytes.toFloat / divisor
  if avgBytes > 4096.0 then some (.dataTooLarge avgBytes) else none
| 302 | + |
/-- Runs the speed, filter-ratio and data-size checks and collects the warnings
they raise, in that order.
NOTE(review): the data-size check averages over `validCount` only, while the
speed check averages over valid + invalid — confirm this asymmetry is intended. -/
def runAll (totalMs : Float) (validCount invalidCount : Nat) (totalBytes : Nat) : List HealthWarning :=
  let examples := validCount + invalidCount
  let speed := checkSpeed totalMs examples
  let filtering := checkFilterRatio validCount invalidCount
  let dataSize := checkDataSize totalBytes validCount
  speed.toList ++ filtering.toList ++ dataSize.toList
| 307 | + |
| 308 | +end HealthCheck |
| 309 | + |
/-- IEEE negative infinity, used as the initial "best score" so that any
finite score compares strictly greater. A large finite literal such as
`-1.0e308` would not do: finite scores at or below it could never win. -/
def negInfinity : Float := -(1.0 / 0.0)
| 311 | + |
/-- Bookkeeping for score-guided search. -/
structure TargetState where
  /-- Highest score observed so far; starts at `negInfinity`. -/
  bestScore : Float := negInfinity
  /-- Choice sequence that produced `bestScore`, if any score has beaten the initial value. -/
  bestChoiceSeq : Option ChoiceSequence := none
  /-- Every observed (score, choice sequence) pair, in observation order. -/
  observations : Array (Float × ChoiceSequence) := #[]
deriving Inhabited
| 317 | + |
| 318 | +namespace Targeted |
| 319 | + |
/-- Records the observation `(score, cs)`; additionally makes it the new best
when `score` is strictly greater than the best score so far. -/
def observe (score : Float) (cs : ChoiceSequence) (state : TargetState) : TargetState :=
  let recorded : TargetState :=
    { state with observations := state.observations.push (score, cs) }
  if score > state.bestScore then
    { recorded with bestScore := score, bestChoiceSeq := some cs }
  else
    recorded
| 326 | + |
/-- Returns a copy of `cs` in which each byte is independently replaced by a
fresh random byte with probability `mutationRate`. The copy has no spans and
its cursor is reset to 0.

The per-byte roll is drawn from the 100 values 0..99 so that the roll divided
by 100 lies in `[0, 1)`: a rate of 1.0 then mutates every byte and a rate of
0.1 hits exactly 10 of the 100 outcomes. (Rolling 0..100 gives 101 outcomes
and skews every rate, e.g. 1.0 would leave a byte untouched on a roll of 100.) -/
def mutate (cs : ChoiceSequence) (mutationRate : Float := 0.1) : IO ChoiceSequence := do
  let mut buf := cs.buffer
  for i in [:cs.buffer.size] do
    let roll ← IO.rand 0 99
    if roll.toFloat / 100.0 < mutationRate then
      let fresh ← IO.rand 0 255
      buf := buf.set! i fresh.toUInt8
  return { cs with buffer := buf, spans := #[], index := 0 }
| 333 | + |
/-- Picks, uniformly at random, one choice sequence from the top quarter of the
observations by score (at least one observation is always eligible).
Returns `none` when nothing has been observed yet. -/
def selectForMutation (state : TargetState) : IO (Option ChoiceSequence) := do
  if state.observations.isEmpty then
    return none
  let ranked := state.observations.qsort (fun a b => a.1 > b.1)
  let poolSize := Nat.max (ranked.size / 4) 1
  let pick ← IO.rand 0 (poolSize - 1)
  return some (ranked[pick]!).2
| 339 | + |
| 340 | +end Targeted |
| 341 | + |
/-- Settings record for the engine.
NOTE(review): none of these fields is read in this section; the per-field notes
below are inferred from the names and must be confirmed against the runner. -/
structure Config where
  /-- Presumably the number of examples to try — confirm. -/
  numTests : Nat := 100
  /-- Presumably the byte length of the first generated buffers — confirm. -/
  initialSize : Nat := 64
  /-- Presumably the upper bound on buffer length (8 KiB by default) — confirm. -/
  maxSize : Nat := 8 * 1024
  /-- Presumably toggles the example database — confirm. -/
  useDb : Bool := true
  /-- Presumably toggles the health checks — confirm. -/
  healthChecks : Bool := true
  /-- Presumably toggles score-guided search — confirm. -/
  targeted : Bool := false
  /-- Presumably the per-byte rate handed to mutation — confirm. -/
  mutationRate : Float := 0.1
  /-- Presumably toggles shrink tracing — confirm. -/
  traceShrink : Bool := false
  /-- Presumably suppresses output — confirm. -/
  quiet : Bool := false
deriving Repr, Inhabited
| 353 | + |
/-- Summary record of a test run.
NOTE(review): nothing in this section constructs a `TestRun`; the field notes
are inferred from names and types. -/
structure TestRun where
  /-- Outcome classification of the run. -/
  status : Status
  /-- Choice sequence associated with the run. -/
  choiceSeq : ChoiceSequence
  /-- Number of shrink steps; 0 unless supplied. -/
  shrinkSteps : Nat := 0
  /-- Health warnings attached to the run; empty unless supplied. -/
  healthWarnings : List HealthWarning := []
  /-- Score-guided search state; the default `TargetState` unless supplied. -/
  targetState : TargetState := {}
deriving Inhabited
| 361 | + |
/-- Runs the strategy for `α` against `cs`. Returns the drawn value together
with the choice sequence after the draw, or `none` on overrun. -/
def runStrategy [Strategy α] (cs : ChoiceSequence) : Option (α × ChoiceSequence) :=
  if let .ok value rest := Strategy.draw (α := α) cs then
    some (value, rest)
  else
    none
| 366 | + |
/--
Greedily shrinks a failing choice sequence.

* `test` — the property under test; `false` means the drawn value is a
  counterexample.
* `cs` — the choice sequence to start from.
* `maxSteps` — upper bound on the number of accepted shrink steps.
* `trace` — when `true`, every accepted step is reported through `dbgTrace`.

Returns the smallest sequence found and the number of accepted steps.

Each round replays the strictly smaller candidates in order and accepts the
first one whose drawn value still fails `test`. The accepted candidate is
re-annotated with the spans recorded during its replay: candidates are built
with `spans := #[]`, so without this the span-based passes (`deleteSpans`,
`zeroSpans`) would have nothing to work on after the first round. For the same
reason a starting sequence without spans is replayed once up front.
-/
partial def shrinkLoop [Strategy α] (test : α → Bool) (cs : ChoiceSequence)
    (maxSteps : Nat := 1000) (trace : Bool := false) : ChoiceSequence × Nat :=
  let rec go (current : ChoiceSequence) (steps : Nat) (fuel : Nat) : ChoiceSequence × Nat :=
    if fuel == 0 then (current, steps)
    else
      let candidates := Shrinker.shrink current |> Shrinker.filterSmaller current
      -- First candidate that still fails, carrying the spans of its own replay.
      let failing := candidates.findSome? fun c =>
        match runStrategy (α := α) c with
        | some (value, drawn) =>
          if test value then none else some { c with spans := drawn.spans }
        | none => none
      match failing with
      | some smaller =>
        if trace then
          dbgTrace s!"[Shrink] {current.size} → {smaller.size} bytes" fun _ =>
            go smaller (steps + 1) (fuel - 1)
        else
          go smaller (steps + 1) (fuel - 1)
      | none => (current, steps)
  -- A sequence that arrives without spans (e.g. loaded from disk) is replayed
  -- from the start so the first round can already use span-based passes.
  let start :=
    if cs.spans.isEmpty then
      match runStrategy (α := α) { cs with index := 0 } with
      | some (_, drawn) => { cs with spans := drawn.spans }
      | none => cs
    else cs
  go start 0 maxSteps
| 386 | + |
| 387 | +end Plausible.Conjecture |
0 commit comments