Skip to content

Combine lhs and mhs in TTP #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop-1.0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
lib_managed
*.iml
.idea/
.DS_Store
local.properties
project/boot
project/build/target
target
virtualization-lms-core.iml
.gitingore
#test-out
target/
test-out/
!test-out/**/*.check
data/
8 changes: 4 additions & 4 deletions src/common/IfThenElse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,14 @@ trait BaseGenIfThenElseFat extends BaseGenIfThenElse with GenericFatCodegen {
import IR._

override def fatten(e: Stm): Stm = e match {
case TP(sym, o: AbstractIfThenElse[_]) =>
TTP(List(sym), List(o), SimpleFatIfThenElse(o.cond, List(o.thenp), List(o.elsep)))
case TP(sym, p @ Reflect(o: AbstractIfThenElse[_], u, es)) => //if !u.maySimple && !u.mayGlobal => // contrary, fusing will not change observable order
case tp @ TP(sym, o: AbstractIfThenElse[_]) =>
TTP(List(tp), SimpleFatIfThenElse(o.cond, List(o.thenp), List(o.elsep)))
case tp @ TP(sym, p @ Reflect(o: AbstractIfThenElse[_], u, es)) => //if !u.maySimple && !u.mayGlobal => // contrary, fusing will not change observable order
// assume body will reflect, too...
printdbg("-- fatten effectful if/then/else " + e)
val e2 = SimpleFatIfThenElse(o.cond, List(o.thenp), List(o.elsep))
e2.extradeps = es //HACK
TTP(List(sym), List(p), e2)
TTP(List(tp), e2)
case _ => super.fatten(e)
}
}
Expand Down
44 changes: 23 additions & 21 deletions src/common/LoopFusionOpt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
val levelScope = getExactScope(currentScope)(result) // could provide as input ...
// TODO: cannot in general fuse several effect loops (one effectful and several pure ones is ok though)
// so we need a strategy. a simple one would be exclude all effectful loops right away (TODO).
levelScope collect { case e @ TTP(_, _, SimpleFatLoop(_,_,_)) => e }
levelScope collect { case e @ TTP(_, SimpleFatLoop(_,_,_)) => e }
}

// FIXME: more than one super call means exponential cost -- is there a better way?
Expand All @@ -284,7 +284,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
var done = false

// keep track of loops in inner scopes
var UloopSyms = currentScope collect { case e @ TTP(lhs, _, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs }
var UloopSyms = currentScope collect { case e @ TTP(_, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => e.lhs }

//do{

Expand Down Expand Up @@ -367,7 +367,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
var partitionsIn = Wloops
var partitionsOut = Nil:List[Stm]

for (b@ TTP(_,_,_) <- partitionsIn) {
for (b@ TTP(_,_) <- partitionsIn) {
// try to add to an item in partitionsOut, if not possible add as-is
partitionsOut.find(a => canFuse(a,b)) match {
case Some(a: TTP) =>
Expand All @@ -394,13 +394,14 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
shapeA
}

val lhs = a.lhs ++ b.lhs
val tps = a.tps ++ b.tps

val fused = TTP(lhs, a.mhs ++ b.mhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
val fused = TTP(tps, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
partitionsOut = fused :: (partitionsOut diff List(a))

val preNeg = WtableNeg collect { case p if (lhs contains p._2) => p._1 }
val postNeg = WtableNeg collect { case p if (lhs contains p._1) => p._2 }
val syms = tps.map(_.sym).toSet
val preNeg = WtableNeg collect { case p if (syms contains p._2) => p._1 }
val postNeg = WtableNeg collect { case p if (syms contains p._1) => p._2 }

val fusedNeg = preNeg flatMap { s1 => postNeg map { s2 => (s1,s2) } }
WtableNeg = (fusedNeg ++ WtableNeg).distinct
Expand Down Expand Up @@ -461,10 +462,10 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif

// prune Wloops (some might be no longer necessary)
Wloops = pOutT map {
case TTP(lhs, mhs, SimpleFatLoop(s, x, rhs)) =>
val ex = lhs map (s => currentScope exists (_.lhs contains s))
case TTP(tps, SimpleFatLoop(s, x, rhs)) =>
val ex = tps map (s => currentScope exists (_.lhs contains s.sym))
def select[A](a: List[A], b: List[Boolean]) = (a zip b) collect { case (w, true) => w }
TTP(select(lhs, ex), select(mhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
TTP(select(tps, ex), SimpleFatLoop(s, x, select(rhs, ex)))
}

currentScope = (currentScope diff pInT) ++ Wloops
Expand Down Expand Up @@ -520,7 +521,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
val levelScope = getExactScope(currentScope)(result) // could provide as input ...
// TODO: cannot in general fuse several effect loops (one effectful and several pure ones is ok though)
// so we need a strategy. a simple one would be exclude all effectful loops right away (TODO).
levelScope collect { case e @ TTP(_, _, SimpleFatLoop(_,_,_)) => e }
levelScope collect { case e @ TTP(_, SimpleFatLoop(_,_,_)) => e }
}

// FIXME: more than one super call means exponential cost -- is there a better way?
Expand All @@ -535,7 +536,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
var done = false

// keep track of loops in inner scopes
var UloopSyms = currentScope collect { case e @ TTP(lhs, _, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs }
var UloopSyms = currentScope collect { case e @ TTP(lhs, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs.map(_.sym) }

do {
// utils
Expand Down Expand Up @@ -630,7 +631,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
var partitionsIn = Wloops
var partitionsOut = Nil:List[Stm]

for (b@ TTP(_,_,_) <- partitionsIn) {
for (b@ TTP(_,_) <- partitionsIn) {
// try to add to an item in partitionsOut, if not possible add as-is
partitionsOut.find(a => canFuse(a,b)) match {
case Some(a: TTP) =>
Expand Down Expand Up @@ -659,11 +660,12 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif

val lhs = a.lhs ++ b.lhs

val fused = TTP(lhs, a.mhs ++ b.mhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
val fused = TTP(lhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
partitionsOut = fused :: (partitionsOut diff List(a))

val preNeg = WtableNeg collect { case p if (lhs contains p._2) => p._1 }
val postNeg = WtableNeg collect { case p if (lhs contains p._1) => p._2 }
val syms = lhs.map(_.sym).toSet
val preNeg = WtableNeg collect { case p if (syms contains p._2) => p._1 }
val postNeg = WtableNeg collect { case p if (syms contains p._1) => p._2 }

val fusedNeg = preNeg flatMap { s1 => postNeg map { s2 => (s1,s2) } }
WtableNeg = (fusedNeg ++ WtableNeg).distinct
Expand Down Expand Up @@ -722,19 +724,19 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif

// prune Wloops (some might be no longer necessary)
Wloops = Wloops map {
case TTP(lhs, mhs, SimpleFatLoop(s, x, rhs)) =>
val ex = lhs map (s => currentScope exists (_.lhs == List(s)))
case TTP(lhs, SimpleFatLoop(s, x, rhs)) =>
val ex = lhs map (s => currentScope exists (_.lhs == List(s.sym)))
def select[A](a: List[A], b: List[Boolean]) = (a zip b) collect { case (w, true) => w }
TTP(select(lhs, ex), select(mhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
TTP(select(lhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
}

// PREVIOUS PROBLEM: don't throw out all loops, might have some that are *not* in levelScope
// note: if we don't do it here, we will likely see a problem going back to innerScope in
// FatCodegen.focusExactScopeFat below. --> how to go back from SimpleFatLoop to VectorPlus??
// UPDATE: UloopSyms puts a tentative fix in place. check if it is sufficient!!
// what is the reason we cannot just look at Wloops??
currentScope = currentScope.filter { case e@TTP(lhs, _, _: AbstractFatLoop) =>
val keep = UloopSyms contains lhs
currentScope = currentScope.filter { case e@TTP(lhs, _: AbstractFatLoop) =>
val keep = UloopSyms contains lhs.map(_.sym)
//if (!keep) println("dropping: " + e + ", not int UloopSyms: " + UloopSyms)
keep case _ => true } ::: Wloops

Expand Down
8 changes: 4 additions & 4 deletions src/common/Loops.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ trait BaseLoopsTraversalFat extends FatBlockTraversal {
import IR._

override def fatten(e: Stm): Stm = e match {
case TP(sym, op: AbstractLoop[_]) =>
TTP(List(sym), List(op), SimpleFatLoop(op.size, op.v, List(op.body)))
case TP(sym, p @ Reflect(op: AbstractLoop[_], u, es)) if !u.maySimple && !u.mayGlobal => // assume body will reflect, too. bring it on...
case tp @ TP(sym, op: AbstractLoop[_]) =>
TTP(List(tp), SimpleFatLoop(op.size, op.v, List(op.body)))
case tp @ TP(sym, p @ Reflect(op: AbstractLoop[_], u, es)) if !u.maySimple && !u.mayGlobal => // assume body will reflect, too. bring it on...
printdbg("-- fatten effectful loop " + e)
TTP(List(sym), List(p), SimpleFatLoop(op.size, op.v, List(op.body)))
TTP(List(tp), SimpleFatLoop(op.size, op.v, List(op.body)))
case _ => super.fatten(e)
}

Expand Down
75 changes: 49 additions & 26 deletions src/common/SimplifyTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,65 +106,88 @@ trait SimplifyTransform extends internal.FatScheduling {
case s: Sym[Any] => (Option(scopeIndex.get(s)) orElse findDefinition(s)).toList // check scope before target graph
case _ => Nil
}
case TTP(lhs, mhs, SimpleFatIfThenElse(c,as,bs)) =>
case TTP(tps, SimpleFatIfThenElse(c,as,bs)) =>
// alternate strategy: transform thin def, then fatten again (a little detour)
printdbg("need to transform rhs of fat if/then/else: " + lhs + ", if " + c + " then " + as + " else " + bs)
val lhs2 = (lhs zip mhs).map { case (s,r) => transformOne(s, r, t) }.distinct.asInstanceOf[List[Sym[Any]]]
val lhs1 = tps.map(_.sym)
printdbg("need to transform rhs of fat if/then/else: " + lhs1 + ", if " + c + " then " + as + " else " + bs)
val lhs2 = tps.map { case TP(s,r) => transformOne(s, r, t) }.distinct.asInstanceOf[List[Sym[Any]]]
val mhs2 = lhs2.map(s => findDefinition(s).get.defines(s).get)
// TBD: we're mirroring the defs in mhs, creating new stms
// we don't really want new stms: if the defs are just abstract descriptions we only want them updated

// this means we'll have both a TP and a TTP defining the same sym in globalDefs --> bad!
// not quite so fast, chances are the TTP's never make it into globalDefs (no toAtom call)!!!

if (lhs != lhs2) {

val lhsesAreDifferent = lhs1 != lhs2
if (lhsesAreDifferent) {
val missing = Nil//(lhs2.map(s => findDefinition(s).get) diff innerScope)
printdbg("lhs changed! will add to innerScope: "+missing.mkString(","))
printdbg("tps changed! will add to innerScope: "+missing.mkString(","))
//innerScope = innerScope ::: missing
}

def infix_toIf(d: Def[Any]) = d match {
case l: AbstractIfThenElse[_] => l
case Reflect(l: AbstractIfThenElse[_], _, _) => l
}
val cond2 = if (lhs != lhs2) mhs2.map (_.toIf.cond) reduceLeft { (s1,s2) => assert(s1==s2,"conditions don't agree: "+s1+","+s2); s1 }
else t(c)
val as2 = (if (lhs != lhs2) (lhs2 zip (mhs2 map (_.toIf.thenp)))
else (lhs zip as)) map { case (s,r) => transformIfBody(s,r,t) }
val bs2 = (if (lhs != lhs2) (lhs2 zip (mhs2 map (_.toIf.elsep)))
else (lhs zip bs)) map { case (s,r) => transformIfBody(s,r,t) }
def branches(ifDifferent: AbstractIfThenElse[_] => Block[_], ifSame: List[Block[_]]) = {
val rhs2 = if (lhsesAreDifferent)
mhs2.map(x => ifDifferent(x.toIf))
else
ifSame
lhs2.zip(rhs2).map { case (s,r) => transformIfBody(s,r,t) }
}

val cond2 = if (lhsesAreDifferent)
mhs2.map(_.toIf.cond).reduceLeft { (s1,s2) =>
assert(s1==s2,"conditions don't agree: "+s1+","+s2)
s1
}
else
t(c)
val as2 = branches(_.thenp, as)
val bs2 = branches(_.elsep, bs)

printdbg("came up with: " + lhs2 + ", if " + cond2 + " then " + as2 + " else " + bs2 + " with subst " + t.subst.mkString(","))
List(TTP(lhs2, mhs2, SimpleFatIfThenElse(cond2,as2,bs2)))
List(TTP(lhs2.zip(mhs2).map { case (s, d) => TP(s,d) }, SimpleFatIfThenElse(cond2,as2,bs2)))

case TTP(lhs, mhs, SimpleFatLoop(s,x,rhs)) =>
case TTP(tps, SimpleFatLoop(s,x,rhs)) =>
// alternate strategy: transform thin def, then fatten again (a little detour)
printdbg("need to transform rhs of fat loop: " + lhs + ", " + rhs)
val lhs2 = (lhs zip mhs).map { case (s,r) => transformOne(s, r, t) }.distinct.asInstanceOf[List[Sym[Any]]]
val lhs1 = tps.map(_.sym)
printdbg("need to transform rhs of fat loop: " + lhs1 + ", " + rhs)
val lhs2 = tps.map { case TP(s,r) => transformOne(s, r, t) }.distinct.asInstanceOf[List[Sym[Any]]]
val mhs2 = lhs2.map(s => findDefinition(s).get.defines(s).get)
if (lhs != lhs2) {
val lhsesAreDifferent = lhs1 != lhs2
if (lhsesAreDifferent) {
val missing = (lhs2.map(s => findDefinition(s).get) diff scope/*innerScope*/)
printdbg("lhs changed! will add to innerScope: "+missing.mkString(","))
printdbg("tps changed! will add to innerScope: "+missing.mkString(","))
//innerScope = innerScope ::: missing
}
//val shape2 = if (lhs != lhs2) lhs2.map { case Def(SimpleLoop(s,_,_)) => s } reduceLeft { (s1,s2) => assert(s1==s2,"shapes don't agree: "+s1+","+s2); s1 }
//val shape2 = if (lhs1 != lhs2) lhs2.map { case Def(SimpleLoop(s,_,_)) => s } reduceLeft { (s1,s2) => assert(s1==s2,"shapes don't agree: "+s1+","+s2); s1 }
def infix_toLoop(d: Def[Any]) = d match {
case l: AbstractLoop[_] => l
case Reflect(l: AbstractLoop[_], _, _) => l
}
val shape2 = if (lhs != lhs2) mhs2.map (_.toLoop.size) reduceLeft { (s1,s2) => assert(s1==s2,"shapes don't agree: "+s1+","+s2); s1 }
else t(s)
val rhs2 = (if (lhs != lhs2) (lhs2 zip (mhs2 map (_.toLoop.body)))
else (lhs zip rhs)) map { case (s,r) => transformLoopBody(s,r,t) }
val shape2 = if (lhsesAreDifferent)
mhs2.map(_.toLoop.size).reduceLeft { (s1,s2) =>
assert(s1==s2,"shapes don't agree: "+s1+","+s2)
s1
}
else
t(s)
val rhs2 = if (lhsesAreDifferent)
mhs2 map (_.toLoop.body)
else
rhs
val rhs3 = lhs2.zip(rhs2).map { case (s,r) => transformLoopBody(s,r,t) }

/* //update innerScope -- change definition of lhs2 in place (necessary?)
innerScope = innerScope map {
case TP(l,_) if lhs2 contains l => TP(l, SimpleLoop(shape2,t(x).asInstanceOf[Sym[Int]],rhs2(lhs2.indexOf(l))))
case TP(l,_) if lhs2 contains l => TP(l, SimpleLoop(shape2,t(x).asInstanceOf[Sym[Int]],rhs3(lhs2.indexOf(l))))
case d => d
}*/

printdbg("came up with: " + lhs2 + ", " + rhs2 + " with subst " + t.subst.mkString(","))
List(TTP(lhs2, mhs2, SimpleFatLoop(shape2,t(x).asInstanceOf[Sym[Int]],rhs2)))
printdbg("came up with: " + lhs2 + ", " + rhs3 + " with subst " + t.subst.mkString(","))
List(TTP(lhs2.zip(mhs2).map { case (s, d) => TP(s,d) }, SimpleFatLoop(shape2,t(x).asInstanceOf[Sym[Int]],rhs3)))
// still problem: VectorSum(a,b) = SimpleLoop(i, ReduceElem(f(i)))
// might need to translate f(i), but looking up VectorSum will not be changed at all!!!
// --> change rhs nonetheless???
Expand Down
30 changes: 15 additions & 15 deletions src/common/SplitEffects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ trait BaseGenSplitEffects extends BaseGenIfThenElseFat with GenericFatCodegen {


override def fatten(e: Stm): Stm = e match {
case TP(s,d@While(c,b)) => TTP(List(s),List(d),SimpleFatWhile(c,List(b)))
case TP(s,d@Reflect(While(c,b),u,es)) =>
case tp @ TP(s,d@While(c,b)) => TTP(List(tp),SimpleFatWhile(c,List(b)))
case tp @ TP(s,d@Reflect(While(c,b),u,es)) =>
val x = SimpleFatWhile(c,List(b))
x.extradeps = es.asInstanceOf[List[Sym[Any]]]
TTP(List(s),List(d),x)
case TP(s,d@Reflect(PreviousIteration(k),u,es)) =>
TTP(List(tp),x)
case tp @ TP(s,d@Reflect(PreviousIteration(k),u,es)) =>
val x = SimpleFatPrevious(k,es.asInstanceOf[List[Sym[Any]]])
TTP(List(s),List(d),x)
TTP(List(tp),x)
case _ => super.fatten(e)
}

Expand All @@ -205,31 +205,31 @@ trait BaseGenSplitEffects extends BaseGenIfThenElseFat with GenericFatCodegen {
//println(e1)

val e2 = e1 collect {
case t@TTP(lhs, mhs, p @ SimpleFatIfThenElse(c,as,bs)) => t
case t@TTP(lhs, mhs, p @ SimpleFatWhile(c,b)) => t
case t@TTP(lhs, mhs, p @ SimpleFatPrevious(k,es)) => t
case t@TTP(_, p @ SimpleFatIfThenElse(c,as,bs)) => t
case t@TTP(_, p @ SimpleFatWhile(c,b)) => t
case t@TTP(_, p @ SimpleFatPrevious(k,es)) => t
}

val m = e2 groupBy {
case t@TTP(lhs, mhs, p @ SimpleFatIfThenElse(c,as,bs)) => (c, "if")
case t@TTP(lhs, mhs, p @ SimpleFatWhile(Block(Def(Reify(c,_,_))),b)) => (c, "while")
case t@TTP(lhs, mhs, p @ SimpleFatPrevious(k,es)) => (k,"prev")
case t@TTP(_, p @ SimpleFatIfThenElse(c,as,bs)) => (c, "if")
case t@TTP(_, p @ SimpleFatWhile(Block(Def(Reify(c,_,_))),b)) => (c, "while")
case t@TTP(_, p @ SimpleFatPrevious(k,es)) => (k,"prev")
}

val e3 = e1 diff e2

val g1 = m map {
case ((c:Exp[Boolean], "if"), ifs: List[TTP]) => TTP(ifs.flatMap(_.lhs), ifs.flatMap(_.mhs),
case ((c:Exp[Boolean], "if"), ifs: List[TTP]) => TTP(ifs.flatMap(_.tps),
SimpleFatIfThenElse(c, ifs.flatMap(_.rhs.asInstanceOf[SimpleFatIfThenElse].thenp),
ifs.flatMap(_.rhs.asInstanceOf[SimpleFatIfThenElse].elsep)))
case ((c, "while"), wls: List[TTP]) =>
val x = SimpleFatWhile(wls.map(_.rhs.asInstanceOf[SimpleFatWhile].cond).apply(0), //FIXME: merge cond!!!
wls.flatMap(_.rhs.asInstanceOf[SimpleFatWhile].body))
x.extradeps = wls.flatMap(_.rhs.asInstanceOf[SimpleFatWhile].extradeps) diff wls.flatMap(_.lhs)
TTP(wls.flatMap(_.lhs), wls.flatMap(_.mhs), // TODO: merge cond blocks!
x.extradeps = wls.flatMap(_.rhs.asInstanceOf[SimpleFatWhile].extradeps) diff wls.flatMap(_.tps)
TTP(wls.flatMap(_.tps), // TODO: merge cond blocks!
x)
case ((k:Exp[Nothing],"prev"), pvs: List[TTP]) =>
TTP(pvs.flatMap(_.lhs), pvs.flatMap(_.mhs), SimpleFatPrevious(k,pvs.flatMap(_.rhs.asInstanceOf[SimpleFatPrevious].extra)))
TTP(pvs.flatMap(_.tps), SimpleFatPrevious(k,pvs.flatMap(_.rhs.asInstanceOf[SimpleFatPrevious].extra)))
}

val r = e3 ++ g1
Expand Down
Loading