diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index dc8bb5fd3..5610a96e6 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -9,7 +9,7 @@ on: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 defaults: run: working-directory: ./ @@ -27,7 +27,7 @@ jobs: - name: Set up dependencies run: | sudo apt-get update - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y git g++ cmake bison flex libboost-all-dev python + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y git g++ cmake bison flex libboost-all-dev 2to3 python-is-python3 sudo DEBIAN_FRONTEND=noninteractive apt-get install -y perl minisat curl gnupg2 locales clang-11 wget - name: Generate test files (LLVM IR) run: | @@ -77,3 +77,4 @@ jobs: sbt 'testOnly gensym.wasm.TestEval' sbt 'testOnly gensym.wasm.TestScriptRun' sbt 'testOnly gensym.wasm.TestConcolic' + sbt 'testOnly gensym.wasm.TestDriver' diff --git a/benchmarks/wasm/branch-strip-buggy.wat b/benchmarks/wasm/branch-strip-buggy.wat new file mode 100644 index 000000000..c957db7f6 --- /dev/null +++ b/benchmarks/wasm/branch-strip-buggy.wat @@ -0,0 +1,46 @@ +(module + (type (;0;) (func (param i32 i32) (result i32))) + (type (;1;) (func (param i32))) + (func (;0;) (type 0) (param i32 i32) (result i32) + local.get 0 + i32.const 0 + i32.le_s + if (result i32) ;; label = @1 + i32.const 1 + else + local.get 1 + i32.const 0 + i32.le_s + end + if (result i32) ;; label = @1 + i32.const -1 + else + local.get 0 + local.get 0 + i32.mul + local.get 1 + local.get 1 + i32.mul + i32.add + i32.const 25 + i32.eq + if (result i32) ;; label = @2 + i32.const 1 + else + i32.const 0 + call 2 + end + end + ) + (export "f" (func 0)) + (func $real_main + ;; TODO: is there a better way to put symbolic values on the stack? + i32.const 0 + i32.symbolic + i32.const 1 + i32.symbolic + call 0 + ) + (import "console" "assert" (func (type 1))) + (export "real_main" (func 1)) +) diff --git a/benchmarks/wasm/branch-strip.wat b/benchmarks/wasm/branch-strip.wat index 81ca303d1..beb6fbc88 100644 --- a/benchmarks/wasm/branch-strip.wat +++ b/benchmarks/wasm/branch-strip.wat @@ -4,11 +4,14 @@ local.get 0 i32.const 0 i32.le_s - local.get 1 - i32.const 0 - i32.le_s - i32.or - if (result i32) ;; label = @1 + if (result i32) ;; label = @1 + i32.const 1 + else + local.get 1 + i32.const 0 + i32.le_s + end + if (result i32) ;; label = @1 i32.const -1 else local.get 0 @@ -20,7 +23,7 @@ i32.add i32.const 25 i32.eq - if (result i32) ;; label = @2 + if (result i32) ;; label = @2 i32.const 1 else i32.const 0 diff --git a/benchmarks/wasm/branch-strip1.wat b/benchmarks/wasm/branch-strip1.wat new file mode 100644 index 000000000..81ca303d1 --- /dev/null +++ b/benchmarks/wasm/branch-strip1.wat @@ -0,0 +1,41 @@ +(module + (type (;0;) (func (param i32 i32) (result i32))) + (func (;0;) (type 0) (param i32 i32) (result i32) + local.get 0 + i32.const 0 + i32.le_s + local.get 1 + i32.const 0 + i32.le_s + i32.or + if (result i32) ;; label = @1 + i32.const -1 + else + local.get 0 + local.get 0 + i32.mul + local.get 1 + local.get 1 + i32.mul + i32.add + i32.const 25 + i32.eq + if (result i32) ;; label = @2 + i32.const 1 + else + i32.const 0 + end + end + ) + (export "f" (func 0)) + (func $real_main + ;; TODO: is there a better way to put symbolic values on the stack? + i32.const 0 + i32.symbolic + i32.const 1 + i32.symbolic + call 0 + ) + + (export "real_main" (func 1)) +) diff --git a/benchmarks/wasm/branch.wat b/benchmarks/wasm/branch.wat index 693f7ae71..98211ca21 100644 --- a/benchmarks/wasm/branch.wat +++ b/benchmarks/wasm/branch.wat @@ -2,9 +2,10 @@ (func $f (param $x i32) (param $y i32) (result i32) ;; if (x <= 0 || y <= 0) (if (result i32) - (i32.or + (if (result i32) (i32.le_s (local.get $x) (i32.const 0)) - (i32.le_s (local.get $y) (i32.const 0)) + (then (i32.const 1)) + (else (i32.le_s (local.get $y) (i32.const 0))) ) (then (i32.const -1)) ;; return -1 (else diff --git a/benchmarks/wasm/branch1.wat b/benchmarks/wasm/branch1.wat new file mode 100644 index 000000000..693f7ae71 --- /dev/null +++ b/benchmarks/wasm/branch1.wat @@ -0,0 +1,29 @@ +(module + (func $f (param $x i32) (param $y i32) (result i32) + ;; if (x <= 0 || y <= 0) + (if (result i32) + (i32.or + (i32.le_s (local.get $x) (i32.const 0)) + (i32.le_s (local.get $y) (i32.const 0)) + ) + (then (i32.const -1)) ;; return -1 + (else + ;; if (x * x + y * y == 25) + (if (result i32) + (i32.eq + (i32.add + (i32.mul (local.get $x) (local.get $x)) + (i32.mul (local.get $y) (local.get $y)) + ) + (i32.const 25) + ) + (then (i32.const 1)) ;; return 1 + (else (i32.const 0)) ;; return 0 + ) + ) + ) + ) + + ;; Optionally export the function + (export "f" (func $f)) +) diff --git a/src/main/scala/wasm/ConcolicDriver.scala b/src/main/scala/wasm/ConcolicDriver.scala index 1f766efde..42f9707aa 100644 --- a/src/main/scala/wasm/ConcolicDriver.scala +++ b/src/main/scala/wasm/ConcolicDriver.scala @@ -9,9 +9,10 @@ import scala.collection.immutable.Queue import scala.collection.mutable.{HashMap, HashSet} import z3.scala._ +import scala.tools.nsc.doc.model.Val object ConcolicDriver { - def condsToEnv(conds: List[Cond])(implicit z3Ctx: Z3Context): HashMap[Int, Value] = { + def condsToEnv(conds: List[Cond])(implicit z3Ctx: Z3Context): Option[HashMap[Int, Value]] = { val intSort = z3Ctx.mkIntSort() val boolSort = z3Ctx.mkBoolSort() @@ -24,16 +25,57 @@ object ConcolicDriver { case Add(_) => z3Ctx.mkAdd(symVToZ3(lhs), symVToZ3(rhs)) // does numtype matter? case Sub(_) => z3Ctx.mkSub(symVToZ3(lhs), symVToZ3(rhs)) case Mul(_) => z3Ctx.mkMul(symVToZ3(lhs), symVToZ3(rhs)) - case _ => ??? + case Or(_) => + var result = z3Ctx.mkBVOr( + z3Ctx.mkInt2BV(32, symVToZ3(lhs)), + z3Ctx.mkInt2BV(32, symVToZ3(rhs)) + ) + z3Ctx.mkBV2Int(result, false) + case _ => throw new NotImplementedError(s"Unsupported binary operation: $op") } case SymUnary(op, v) => op match { case _ => ??? } case SymIte(cond, thenV, elseV) => z3Ctx.mkITE(condToZ3(cond), symVToZ3(thenV), symVToZ3(elseV)) - case Concrete(v) => ??? - case _ => ??? + case Concrete(v) => + v match { + // todo: replace with bitvector + case I32V(i) => z3Ctx.mkInt(i, intSort) + case I64V(i) => z3Ctx.mkNumeral(i.toString(), intSort) + // TODO: Float + case _ => ??? + } + case RelCond(op, lhs, rhs) => + val res = op match { + case GeS(_) => z3Ctx.mkGE(symVToZ3(lhs), symVToZ3(rhs)) + case GtS(_) => z3Ctx.mkGT(symVToZ3(lhs), symVToZ3(rhs)) + case LtS(_) => z3Ctx.mkLT(symVToZ3(lhs), symVToZ3(rhs)) + case LeS(_) => z3Ctx.mkLE(symVToZ3(lhs), symVToZ3(rhs)) + case GtU(_) => z3Ctx.mkGT(symVToZ3(lhs), symVToZ3(rhs)) + case GeU(_) => z3Ctx.mkGE(symVToZ3(lhs), symVToZ3(rhs)) + case LtU(_) => z3Ctx.mkLT(symVToZ3(lhs), symVToZ3(rhs)) + case LeU(_) => z3Ctx.mkLE(symVToZ3(lhs), symVToZ3(rhs)) + case Eq(_) => z3Ctx.mkEq(symVToZ3(lhs), symVToZ3(rhs)) + case Ne(_) => z3Ctx.mkNot(z3Ctx.mkEq(symVToZ3(lhs), symVToZ3(rhs))) + case Ge(_) => z3Ctx.mkGE(symVToZ3(lhs), symVToZ3(rhs)) + case Gt(_) => z3Ctx.mkGT(symVToZ3(lhs), symVToZ3(rhs)) + case Le(_) => z3Ctx.mkLE(symVToZ3(lhs), symVToZ3(rhs)) + case Lt(_) => z3Ctx.mkLT(symVToZ3(lhs), symVToZ3(rhs)) + } + // convert resutl to int + z3Ctx.mkITE(res, z3Ctx.mkInt(1, intSort), z3Ctx.mkInt(0, intSort)) + case _ => throw new NotImplementedError(s"Unsupported SymVal: $symV") + } + + def getIndexOfSym(sym: String): Int = { + val pattern = ".*_(\\d+)$".r + sym match { + case pattern(index) => index.toInt + case _ => throw new IllegalArgumentException(s"Invalid symbol format: $sym") + } } + def condToZ3(cond: Cond): Z3AST = cond match { case CondEqz(v) => z3Ctx.mkEq(symVToZ3(v), z3Ctx.mkInt(0, intSort)) case Not(cond) => z3Ctx.mkNot(condToZ3(cond)) @@ -49,42 +91,55 @@ object ConcolicDriver { } // solve for all vars + println(s"solving constraints: ${solver.toString()}") solver.check() match { case Some(true) => { val model = solver.getModel() val vars = model.getConsts - val env = HashMap() + val env = HashMap[Int, Value]() for (v <- vars) { val name = v.getName val ast = z3Ctx.mkConst(name, intSort) val value = model.eval(ast) + println(s"name: $name") + println(s"value: $value") + // TODO: support other types of symbolic values(currently only i32) val intValue = if (value.isDefined && value.get.getSort.isIntSort) { - I32V(value.toString.toInt) + val negPattern = """\(\-\s*(\d+)\)""".r + val plainPattern = """(-?\d+)""".r + val num = value.get.toString match { + case negPattern(digits) => -digits.toInt + case plainPattern(number) => number.toInt + case _ => throw new IllegalArgumentException("Invalid format") + } + I32V(num) } else { ??? } - // env += (name.toString -> intValue) - ??? + env += (getIndexOfSym(name.toString) -> intValue) } - ??? - // env + println(s"solved env: $env") + Some(env) } - case _ => ??? + case _ => None } } def negateCond(conds: List[Cond], i: Int): List[Cond] = { - ??? + conds(i).negated :: conds.drop(i + 1) } def checkPCToFile(pc: List[Cond]): Unit = { + // TODO: what this function for? ??? } def exec(module: Module, mainFun: String, startEnv: HashMap[Int, Value])(implicit z3Ctx: Z3Context) = { val worklist = Queue(startEnv) - // val visited = ??? // how to avoid re-execution - + val unreachables = HashSet[ExploreTree]() + val visited = HashSet[ExploreTree]() + // the root node of exploration tree + val root = new ExploreTree() def loop(worklist: Queue[HashMap[Int, Value]]): Unit = worklist match { case Queue() => () case env +: rest => { @@ -92,31 +147,37 @@ object ConcolicDriver { Evaluator(moduleInst).execWholeProgram( Some(mainFun), env, - (_endStack, _endSymStack, pathConds) => { - println(s"env: $env") - val newEnv = condsToEnv(pathConds) - val newWork = for (i <- 0 until pathConds.length) yield { - val newConds = negateCond(pathConds, i) - checkPCToFile(newConds) - condsToEnv(newConds) + root, + (_endStack, _endSymStack, tree) => { + tree.fillWithFinished() + val unexploredTrees = root.unexploredTrees() + // if a node is already visited or marked as unreachable, don't try to explore it + val addedNewWork = unexploredTrees.filterNot(unreachables.contains) + .filterNot(visited.contains) + .flatMap { tree => + val conds = tree.collectConds() + val newEnv = condsToEnv(conds) + // if the path conditions to reach this node are unsatisfiable, mark it as unreachable. + if (newEnv.isEmpty) unreachables.add(tree) + newEnv + } + for (tree <- unexploredTrees) { + visited.add(tree) } - loop(rest ++ newWork) + loop(rest ++ addedNewWork) } ) } } loop(worklist) + println(s"unreachable trees number: ${unreachables.size}") + println(s"number of normal explored paths: ${root.finishedTrees().size}") + val failedTrees = root.failedTrees() + println(s"number of failed explored paths: ${failedTrees.size}") + for (tree <- failedTrees) { + println(s"find a failed endpoint: ${tree}") + } + println(s"exploration tree: ${root.toString}") } } - -object DriverSimpleTest { - def fileTestDriver(file: String, mainFun: String, startEnv: HashMap[Int, Value]) = { - import gensym.wasm.concolicminiwasm._ - import collection.mutable.ArrayBuffer - val module = Parser.parseFile(file) - ConcolicDriver.exec(module, mainFun, startEnv)(new Z3Context()) - } - - def main(args: Array[String]) = {} -} diff --git a/src/main/scala/wasm/ConcolicMiniWasm.scala b/src/main/scala/wasm/ConcolicMiniWasm.scala index 9c621e47e..849fd8314 100644 --- a/src/main/scala/wasm/ConcolicMiniWasm.scala +++ b/src/main/scala/wasm/ConcolicMiniWasm.scala @@ -33,6 +33,9 @@ object ModuleInstance { } object Primitives { + // a random number generator with fixed seed + val rng = new Random(0) + def evalBinOp(op: BinOp, lhs: Value, rhs: Value): Value = op match { case Add(_) => (lhs, rhs) match { @@ -221,11 +224,20 @@ object Primitives { } def randomOfTy(ty: ValueType): Value = ty match { - case NumType(I32Type) => I32V(Random.nextInt()) - case NumType(I64Type) => I64V(Random.nextLong()) - case NumType(F32Type) => F32V(Random.nextFloat()) - case NumType(F64Type) => F64V(Random.nextDouble()) + case NumType(I32Type) => I32V(rng.nextInt()) + case NumType(I64Type) => I64V(rng.nextLong()) + case NumType(F32Type) => F32V(rng.nextFloat()) + case NumType(F64Type) => F64V(rng.nextDouble()) } + + def getFuncType(ty: BlockType): FuncType = + ty match { + case VarBlockType(_, None) => + ??? // TODO: fill this branch until we handle type index correctly + case VarBlockType(_, Some(tipe)) => tipe + case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) + case ValBlockType(None) => FuncType(List(), List(), List()) + } } case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value], symLocals: ArrayBuffer[SymVal]) @@ -233,8 +245,10 @@ case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value], symLocals: case class Evaluator(module: ModuleInstance) { import Primitives._ - type RetCont = (List[Value], List[SymVal], List[Cond]) => Unit - type Cont = (List[Value], List[SymVal], List[Cond]) => Unit + type Cont = (List[Value], List[SymVal], ExploreTree) => Unit + type RetCont = Cont + + var driverCont: Cont = null; val symEnv = HashMap[Int, Value]() @@ -244,21 +258,25 @@ case class Evaluator(module: ModuleInstance) { concStack: List[Value], symStack: List[SymVal], frame: Frame, - ret: RetCont, + kont: RetCont, trail: List[Cont] - )(implicit pathConds: List[Cond]): Unit = { - if (insts.isEmpty) return ret(concStack, symStack, pathConds) + )(implicit tree: ExploreTree): Unit = { + if (insts.isEmpty) return kont(concStack, symStack, tree) - println(s"pathConds: $pathConds") val inst = insts.head val rest = insts.tail - println(s"inst: $inst, concStack: $concStack, symStack: $symStack") + println( + s"""|inst: $inst + |rest: $rest + |concStack: $concStack + |symStack: $symStack + |pathConds: ${tree.collectConds()}""".stripMargin) inst match { case PushSym(name, v) => - eval(rest, v :: concStack, SymV(name) :: symStack, frame, ret, trail) + eval(rest, v :: concStack, SymV(name) :: symStack, frame, kont, trail) case Symbolic(ty) => val I32V(symIndex) :: newStack = concStack val symVal = SymV(s"sym_$symIndex") @@ -266,47 +284,47 @@ case class Evaluator(module: ModuleInstance) { symEnv(symIndex) = Primitives.randomOfTy(ty) } val v = symEnv(symIndex) - eval(rest, v :: newStack, symVal :: symStack, frame, ret, trail) - case Drop => eval(rest, concStack.tail, symStack.tail, frame, ret, trail) + eval(rest, v :: newStack, symVal :: symStack.tail, frame, kont, trail) + case Drop => eval(rest, concStack.tail, symStack.tail, frame, kont, trail) case Select(_) => val I32V(cond) :: v2 :: v1 :: newStack = concStack val symCond :: symV2 :: symV1 :: newSymStack = symStack val value = if (cond == 0) v1 else v2 val symVal = SymIte(CondEqz(symCond), symV1, symV2) - eval(rest, value :: newStack, symVal :: newSymStack, frame, ret, trail) + eval(rest, value :: newStack, symVal :: newSymStack, frame, kont, trail) case LocalGet(i) => - eval(rest, frame.locals(i) :: concStack, frame.symLocals(i) :: symStack, frame, ret, trail) + eval(rest, frame.locals(i) :: concStack, frame.symLocals(i) :: symStack, frame, kont, trail) case LocalSet(i) => val value :: newStack = concStack val symVal :: newSymStack = symStack frame.locals(i) = value frame.symLocals(i) = symVal - eval(rest, newStack, newSymStack, frame, ret, trail) + eval(rest, newStack, newSymStack, frame, kont, trail) case LocalTee(i) => val value :: _ = concStack val symVal :: _ = symStack frame.locals(i) = value frame.symLocals(i) = symVal - eval(rest, concStack, symStack, frame, ret, trail) + eval(rest, concStack, symStack, frame, kont, trail) case GlobalGet(i) => val (conc, sym) = frame.module.globals(i) - eval(rest, conc.value :: concStack, sym :: symStack, frame, ret, trail) + eval(rest, conc.value :: concStack, sym :: symStack, frame, kont, trail) case GlobalSet(i) => val value :: newStack = concStack val symVal :: newSymStack = symStack val oldConc = frame.module.globals(i)._1 frame.module.globals(i) = (oldConc.copy(value = value), symVal) - eval(rest, newStack, newSymStack, frame, ret, trail) + eval(rest, newStack, newSymStack, frame, kont, trail) // I think these are essentially dummies in WASP // to more accurately replace them, we should probably // add a dummy memory size field to ConcolicMemory case MemorySize => - eval(rest, I32V(100) :: concStack, Concrete(I32V(100)) :: symStack, frame, ret, trail) + eval(rest, I32V(100) :: concStack, Concrete(I32V(100)) :: symStack, frame, kont, trail) // val cv = I32V(frame.module.memory.head.size) // val sv = Concrete(cv) // eval(rest, cv::concStack, sv::symStack, frame, ret, trail) case MemoryGrow => - eval(rest, I32V(100) :: concStack, Concrete(I32V(100)) :: symStack, frame, ret, trail) + eval(rest, I32V(100) :: concStack, Concrete(I32V(100)) :: symStack, frame, kont, trail) // val I32V(delta)::newStack = concStack // val mem = frame.module.memory.head // val oldSize = mem.size @@ -330,81 +348,94 @@ case class Evaluator(module: ModuleInstance) { // frame.module.memory.head.copy(dest, src, n) // eval(rest, newStack, frame, ret, trail) // } - case Const(n) => eval(rest, n :: concStack, Concrete(n) :: symStack, frame, ret, trail) + case Const(n) => eval(rest, n :: concStack, Concrete(n) :: symStack, frame, kont, trail) case Binary(op) => val v2 :: v1 :: newStack = concStack val sv2 :: sv1 :: newSymStack = symStack - eval(rest, evalBinOp(op, v1, v2) :: newStack, evalSymBinOp(op, sv1, sv2) :: newSymStack, frame, ret, trail) + eval(rest, evalBinOp(op, v1, v2) :: newStack, evalSymBinOp(op, sv1, sv2) :: newSymStack, frame, kont, trail) case Unary(op) => val v :: newStack = concStack val sv :: newSymStack = symStack - eval(rest, evalUnaryOp(op, v) :: newStack, SymUnary(op, sv) :: newSymStack, frame, ret, trail) + eval(rest, evalUnaryOp(op, v) :: newStack, SymUnary(op, sv) :: newSymStack, frame, kont, trail) case Compare(op) => val v2 :: v1 :: newStack = concStack val sv2 :: sv1 :: newSymStack = symStack - eval(rest, evalRelOp(op, v1, v2) :: newStack, evalSymRelOp(op, sv1, sv2) :: newSymStack, frame, ret, trail) + eval(rest, evalRelOp(op, v1, v2) :: newStack, evalSymRelOp(op, sv1, sv2) :: newSymStack, frame, kont, trail) case Test(op) => val v :: newStack = concStack val sv :: newSymStack = symStack val test = evalTestOp(op, v) val symTest = evalSymTestOp(op, sv) - eval(rest, test :: newStack, symTest :: newSymStack, frame, ret, trail) + eval(rest, test :: newStack, symTest :: newSymStack, frame, kont, trail) case Store(StoreOp(align, offset, ty, None)) => val I32V(v) :: I32V(addr) :: newStack = concStack val sv :: sa :: newSymStack = symStack // need to concretize sa and then checkAccess frame.module.memory(0).storeInt(addr + offset, (v, sv)) - eval(rest, newStack, symStack.drop(2), frame, ret, trail) + eval(rest, newStack, symStack.drop(2), frame, kont, trail) case Load(LoadOp(align, offset, ty, None, None)) => val I32V(addr) :: newStack = concStack val sa :: newSymStack = symStack // need to concretize sv and then checkAccess val (value, sv) = frame.module.memory(0).loadInt(addr + offset) - eval(rest, I32V(value) :: newStack, sv :: newSymStack, frame, ret, trail) + eval(rest, I32V(value) :: newStack, sv :: newSymStack, frame, kont, trail) case Nop => - eval(rest, concStack, symStack, frame, ret, trail) + eval(rest, concStack, symStack, frame, kont, trail) case Unreachable => throw new RuntimeException("Unreachable") case Block(ty, inner) => - val k: Cont = (retStack, retSymStack, newPathConds) => - eval(rest, concStack ++ retStack, symStack ++ retSymStack, frame, ret, trail)(newPathConds) - - eval(inner, List(), List(), frame, k, k :: trail) + val funcTy = getFuncType(ty) + val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) + val (inputs, restStack) = concStack.splitAt(inputSize) + val (symInputs, restSymStack) = symStack.splitAt(inputSize) + val restK: Cont = (retStack, retSymStack, tree) => + eval(rest, retStack.take(outputSize) ++ restStack, retSymStack.take(outputSize) ++ restSymStack, frame, kont, trail)(tree) + eval(inner, inputs, symInputs, frame, restK, restK :: trail) case Loop(ty, inner) => - val k: Cont = (retStack, retSymStack, newPathConds) => - eval(insts, concStack ++ retStack, symStack ++ retSymStack, frame, ret, trail)(newPathConds) - eval(inner, List(), List(), frame, k, k :: trail) + val funcTy = getFuncType(ty) + val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) + val (inputs, restStack) = concStack.splitAt(inputSize) + val (symInputs, restSymStack) = symStack.splitAt(inputSize) + val restK: Cont = (retStack, retSymStack, tree) => + eval(rest, retStack.take(outputSize) ++ restStack, retSymStack.take(outputSize) ++ restSymStack, frame, kont, trail)(tree) + def loop(retStack: List[Value], retSymStack: List[SymVal], tree: ExploreTree): Unit = + eval(inner, retStack.take(inputSize), retSymStack.take(inputSize), frame, restK, loop _ :: trail)(tree) + loop(inputs, symInputs, tree) case If(ty, thn, els) => val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack - val inner = if (cond != 0) thn else els - val newPathConds = scnd match { - case Concrete(_) => pathConds - case _ => if (cond != 0) CondEqz(scnd) :: pathConds else Not(CondEqz(scnd)) :: pathConds + val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { + // if this is a concrete value, we don't need to put + (tree, tree) + } else { + val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) + (ifElseNode.thenNode, ifElseNode.elseNode) } - val k: Cont = (retStack, retSymStack, newPathConds) => - eval(rest, retStack ++ newStack, retSymStack ++ newSymStack, frame, ret, trail)(newPathConds) - eval(inner, List(), List(), frame, ret, k :: trail)(newPathConds) + val restK: Cont = (retStack, retSymStack, tree) => + eval(rest, retStack ++ newStack, retSymStack ++ newSymStack, frame, kont, trail)(tree) + if (cond != 0) + eval(thn, List(), List(), frame, restK, restK :: trail)(ifNode) + else + eval(els, List(), List(), frame, restK, restK :: trail)(elseNode) case Br(label) => - trail(label)(concStack, symStack, pathConds) + trail(label)(concStack, symStack, tree) case BrIf(label) => val scnd :: newSymStack = symStack val I32V(cond) :: newStack = concStack - val newPathConds = scnd match { - case Concrete(_) => pathConds - case _ => if (cond != 0) CondEqz(scnd) :: pathConds else Not(CondEqz(scnd)) :: pathConds + val (ifNode, elseNode) = if (scnd.isInstanceOf[Concrete]) { + // if this is a concrete value, we don't need to put + (tree, tree) + } else { + val ifElseNode = tree.fillWithIfElse(Not(CondEqz(scnd))) + (ifElseNode.thenNode, ifElseNode.elseNode) } - if (cond == 0) eval(rest, newStack, newSymStack, frame, ret, trail)(newPathConds) - else trail(label)(newStack, newSymStack, newPathConds) - case Return => ret(concStack, symStack, pathConds) - case Call(f) => - evalCall(rest, concStack, symStack, frame, ret, trail, f, false) + if (cond == 0) eval(rest, newStack, newSymStack, frame, kont, trail)(ifNode) + else trail(label)(newStack, newSymStack, elseNode) + case Return => trail.last(concStack, symStack, tree) + case Call(f) => evalCall(rest, concStack, symStack, frame, kont, trail, f, false) case _ => ??? } } - // def eval(insts: List[Instr], concStack: List[Value], symStack: List[SymVal], - // frame: Frame, ret: RetCont, trail: List[Cont])(implicit pathConds: List[Cond]) - def evalCall( rest: List[Instr], concStack: List[Value], @@ -414,7 +445,7 @@ case class Evaluator(module: ModuleInstance) { trail: List[Cont], funcIndex: Int, isTail: Boolean - )(implicit pathConds: List[Cond]): Unit = + )(implicit tree: ExploreTree): Unit = module.funcs(funcIndex) match { case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => println(s"call $funcIndex: locals: ${locals}") @@ -425,12 +456,20 @@ case class Evaluator(module: ModuleInstance) { val frameLocals = args ++ locals.map(_ => I32V(0)) // GW: always I32? or depending on their types? val symFrameLocals = symArgs ++ locals.map(_ => Concrete(I32V(0))) val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*), ArrayBuffer(symFrameLocals: _*)) - val newRet: RetCont = (retStack, retSymStack, newPathConds) => - eval(rest, retStack ++ newStack, retSymStack ++ newSymStack, frame, ret, trail)(newPathConds) - // val k: Cont = (retStack, symStack) => - // eval(rest, retStack, frame, ret, trail) - eval(body, List(), List(), newFrame, newRet, newRet :: trail) // GW: should we install new trail cont? - + val restK: RetCont = (retStack, retSymStack, tree) => + eval(rest, retStack ++ newStack, retSymStack ++ newSymStack, frame, ret, trail)(tree) + eval(body, List(), List(), newFrame, restK, List(restK)) // GW: should we install new trail cont? + case Import("console", "assert", _) => + val I32V(v) :: newStack = concStack + val sv :: newSymStack = symStack + if (v == 0) { + println(s"Assertion failed: find a bug with input $symEnv") + tree.fillWithFail(symEnv) + // go to toplevel halt continuation + driverCont(concStack, symStack, tree) + } else { + eval(rest, newStack, newSymStack, frame, ret, trail) + } // TODO: clean up the other cases // case Import("console", "log", _) => // val I32V(v) :: newStack = stack @@ -441,7 +480,7 @@ case class Evaluator(module: ModuleInstance) { // println(v) // eval(rest, newStack, frame, kont, trail) // case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") - case _ => throw new Exception(s"Definition at $funcIndex is not callable") + case _ => throw new Exception(s"Definition ${module.funcs(funcIndex)} at $funcIndex is not callable") } // TODO: seems bad, global might have an expression to evaluate (maybe only constants?) @@ -458,23 +497,26 @@ case class Evaluator(module: ModuleInstance) { sv = symStack.head }, List() - )(List()) + )(new ExploreTree()) // TODO: Should the execution path of global expressions be ignored like this? (cv, sv) } - private def printRetCont(concStack: List[Value], symStack: List[SymVal], pathConds: List[Cond]) = { + private def printRetCont(concStack: List[Value], symStack: List[SymVal], tree: ExploreTree) = { println(s"retCont: $concStack") println(s"symStack: $symStack") - println(s"pathCnds: $pathConds") + println(s"pathConds: ${tree.collectConds()}") } def execWholeProgram( main: Option[String] = None, symEnv: HashMap[Int, Value] = HashMap(), - k: RetCont = printRetCont + root: ExploreTree = new ExploreTree(), + retCont: RetCont = printRetCont ) = { import collection.mutable.ArrayBuffer + driverCont = retCont + this.symEnv.clear() this.symEnv ++= symEnv @@ -507,7 +549,7 @@ case class Evaluator(module: ModuleInstance) { name == Some(main) }) - print(s"instrs: $instrs") + println(s"instrs: $instrs") // val instrs = List(Call(funcId)) val globals = module.defs.collect({ case g @ Global(_, _) => g }) @@ -536,9 +578,9 @@ case class Evaluator(module: ModuleInstance) { List(), // frame, Frame(module, ArrayBuffer(I32V(0)), ArrayBuffer(Concrete(I32V(0)))), - k, - List((newStack, _, _) => println(s"trail: $newStack")) - )(List()) + retCont, + List(retCont) + )(root) } diff --git a/src/main/scala/wasm/Symbolic.scala b/src/main/scala/wasm/Symbolic.scala index cb87f3f83..c2333fa0d 100644 --- a/src/main/scala/wasm/Symbolic.scala +++ b/src/main/scala/wasm/Symbolic.scala @@ -2,6 +2,7 @@ package gensym.wasm.symbolic import gensym.wasm.ast._ import z3.scala._ +import scala.collection.mutable.HashMap case class SymV(name: String) extends SymVal case class SymBinary(op: BinOp, lhs: SymVal, rhs: SymVal) extends SymVal @@ -10,7 +11,12 @@ case class SymIte(cond: Cond, thn: SymVal, els: SymVal) extends SymVal case class Concrete(v: Value) extends SymVal // The following should be encoded to boolean in SMT -abstract class Cond extends SymVal +abstract class Cond extends SymVal { + def negated: Cond = this match { + case Not(cond) => cond + case _ => Not(this) + } +} case class CondEqz(v: SymVal) extends Cond case class Not(cond: Cond) extends Cond case class RelCond(op: RelOp, lhs: SymVal, rhs: SymVal) extends Cond @@ -24,3 +30,113 @@ abstract class SymVal { case _ => ??? } } + +// consider using zipper to simplify mutations +class ExploreTree(var node: Node = UnExplored(), val parent: Option[ExploreTree] = None) { + def collectConds(): List[Cond] = { + this.parent match { + case Some(parent) => parent.node match { + case IfElse(cond, thenNode, elseNode) => + if (this eq thenNode) { + cond :: parent.collectConds() + } else if (this eq elseNode) { + cond.negated :: parent.collectConds() + } else { + throw new Exception("Internal Error: a tree is not pointed by its parent!") + } + case _ => throw new Exception(s"Internal Error: ${parent.node} is not a valid parent node!") + } + case None => Nil + } + } + + def fillWithIfElse(cond: Cond): IfElse = { + node match { + case UnExplored() => { + var newNode = IfElse(cond, new ExploreTree(parent = Some(this)), new ExploreTree(parent = Some(this))) + node = newNode + newNode + } + case node@IfElse(_, _, _) => node + case _ => throw new Exception("Internal Error: Some exploration paths are not compatible!") + } + } + + def fillWithFinished(): Unit = { + node match { + case UnExplored() => node = Finished() + case Finished() => + println(s"Warning: path to ${this} has been re-executed!") + case Fail(_) => () + case _ => + throw new Exception("Internal Error: Some exploration paths are not compatible!") + } + } + + def fillWithFail(env: HashMap[Int, Value]): Unit = { + node match { + case UnExplored() => node = Fail(env) + case _ => + throw new Exception("Internal Error: Some exploration paths are not compatible!") + } + } + + def unexploredTrees(): List[ExploreTree] = { + node match { + case UnExplored() => List(this) + case IfElse(_, thenNode, elseNode) => + thenNode.unexploredTrees() ++ elseNode.unexploredTrees() + case _ => Nil + } + } + + def finishedTrees(): List[ExploreTree] = { + node match { + case Finished() => List(this) + case IfElse(_, thenNode, elseNode) => + thenNode.finishedTrees() ++ elseNode.finishedTrees() + case _ => Nil + } + } + + def failedTrees(): List[ExploreTree] = { + node match { + case Fail(_) => List(this) + case IfElse(_, thenNode, elseNode) => + thenNode.failedTrees() ++ elseNode.failedTrees() + case _ => Nil + } + } + + // parent node will not affect the sub-tree's structure. + // we ignore it when printing for now + override def toString(): String = node.toString() +} + +sealed abstract class Node { + def cond: Option[Cond] +} + +case class IfElse( + _cond: Cond, + thenNode: ExploreTree, + elseNode: ExploreTree +) extends Node { + // subnodes' parent should point to current tree + + def cond: Option[Cond] = Some(_cond) +} + + +case class UnExplored() extends Node { + def cond = None +} + +case class Finished() extends Node { + def cond = None +} + +case class Fail(env: HashMap[Int, Value]) extends Node { + def cond = None +} + diff --git a/src/test/scala/genwasym/TestConcolic.scala b/src/test/scala/genwasym/TestConcolic.scala index dcff72400..767b8443c 100644 --- a/src/test/scala/genwasym/TestConcolic.scala +++ b/src/test/scala/genwasym/TestConcolic.scala @@ -43,7 +43,12 @@ class TestDriver extends FunSuite { // TODO: fix this test("driver") { - fileTestDriver("./benchmarks/wasm/branch-strip.wat", "real_main", new HashMap[Int, Value]()) + fileTestDriver("./benchmarks/wasm/branch-strip.wat", "real_main", HashMap()) + fileTestDriver("./benchmarks/wasm/branch-strip1.wat", "real_main", HashMap()) + } + + test("bug-finding") { + fileTestDriver("./benchmarks/wasm/branch-strip-buggy.wat", "real_main", HashMap()) } }