@@ -269,7 +269,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
269
269
val levelScope = getExactScope(currentScope)(result) // could provide as input ...
270
270
// TODO: cannot in general fuse several effect loops (one effectful and several pure ones is ok though)
271
271
// so we need a strategy. a simple one would be to exclude all effectful loops right away (TODO).
272
- levelScope collect { case e @ TTP (_, _, SimpleFatLoop (_,_,_)) => e }
272
+ levelScope collect { case e @ TTP (_, SimpleFatLoop (_,_,_)) => e }
273
273
}
274
274
275
275
// FIXME: more than one super call means exponential cost -- is there a better way?
@@ -284,7 +284,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
284
284
var done = false
285
285
286
286
// keep track of loops in inner scopes
287
- var UloopSyms = currentScope collect { case e @ TTP (lhs, _, SimpleFatLoop (_,_,_)) if ! Wloops .contains(e) => lhs }
287
+ var UloopSyms = currentScope collect { case e @ TTP (_, SimpleFatLoop (_,_,_)) if ! Wloops .contains(e) => e. lhs }
288
288
289
289
// do{
290
290
@@ -367,7 +367,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
367
367
var partitionsIn = Wloops
368
368
var partitionsOut = Nil : List [Stm ]
369
369
370
- for (b@ TTP (_,_,_ ) <- partitionsIn) {
370
+ for (b@ TTP (_,_) <- partitionsIn) {
371
371
// try to add to an item in partitionsOut, if not possible add as-is
372
372
partitionsOut.find(a => canFuse(a,b)) match {
373
373
case Some (a : TTP ) =>
@@ -394,13 +394,14 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
394
394
shapeA
395
395
}
396
396
397
- val lhs = a.lhs ++ b.lhs
397
+ val tps = a.tps ++ b.tps
398
398
399
- val fused = TTP (lhs, a.mhs ++ b.mhs , SimpleFatLoop (shape, targetVar, WgetLoopRes (a)::: WgetLoopRes (b)))
399
+ val fused = TTP (tps , SimpleFatLoop (shape, targetVar, WgetLoopRes (a)::: WgetLoopRes (b)))
400
400
partitionsOut = fused :: (partitionsOut diff List (a))
401
401
402
- val preNeg = WtableNeg collect { case p if (lhs contains p._2) => p._1 }
403
- val postNeg = WtableNeg collect { case p if (lhs contains p._1) => p._2 }
402
+ val syms = tps.map(_.sym).toSet
403
+ val preNeg = WtableNeg collect { case p if (syms contains p._2) => p._1 }
404
+ val postNeg = WtableNeg collect { case p if (syms contains p._1) => p._2 }
404
405
405
406
val fusedNeg = preNeg flatMap { s1 => postNeg map { s2 => (s1,s2) } }
406
407
WtableNeg = (fusedNeg ++ WtableNeg ).distinct
@@ -461,10 +462,10 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
461
462
462
463
// prune Wloops (some might be no longer necessary)
463
464
Wloops = pOutT map {
464
- case TTP (lhs, mhs , SimpleFatLoop (s, x, rhs)) =>
465
- val ex = lhs map (s => currentScope exists (_.lhs contains s))
465
+ case TTP (tps , SimpleFatLoop (s, x, rhs)) =>
466
+ val ex = tps map (s => currentScope exists (_.lhs contains s.sym ))
466
467
def select [A ](a : List [A ], b : List [Boolean ]) = (a zip b) collect { case (w, true ) => w }
467
- TTP (select(lhs, ex), select(mhs , ex), SimpleFatLoop (s, x, select(rhs, ex)))
468
+ TTP (select(tps , ex), SimpleFatLoop (s, x, select(rhs, ex)))
468
469
}
469
470
470
471
currentScope = (currentScope diff pInT) ++ Wloops
@@ -520,7 +521,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
520
521
val levelScope = getExactScope(currentScope)(result) // could provide as input ...
521
522
// TODO: cannot in general fuse several effect loops (one effectful and several pure ones is ok though)
522
523
// so we need a strategy. a simple one would be to exclude all effectful loops right away (TODO).
523
- levelScope collect { case e @ TTP(_, _, SimpleFatLoop(_,_,_)) => e }
524
+ levelScope collect { case e @ TTP(_, SimpleFatLoop(_,_,_)) => e }
524
525
}
525
526
526
527
// FIXME: more than one super call means exponential cost -- is there a better way?
@@ -535,7 +536,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
535
536
var done = false
536
537
537
538
// keep track of loops in inner scopes
538
- var UloopSyms = currentScope collect { case e @ TTP(lhs, _, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs }
539
+ var UloopSyms = currentScope collect { case e @ TTP(lhs, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs.map(_.sym) }
539
540
540
541
do {
541
542
// utils
@@ -630,7 +631,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
630
631
var partitionsIn = Wloops
631
632
var partitionsOut = Nil:List[Stm]
632
633
633
- for (b@ TTP(_,_,_ ) <- partitionsIn) {
634
+ for (b@ TTP(_,_) <- partitionsIn) {
634
635
// try to add to an item in partitionsOut, if not possible add as-is
635
636
partitionsOut.find(a => canFuse(a,b)) match {
636
637
case Some(a: TTP) =>
@@ -659,11 +660,12 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
659
660
660
661
val lhs = a.lhs ++ b.lhs
661
662
662
- val fused = TTP(lhs, a.mhs ++ b.mhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
663
+ val fused = TTP(lhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
663
664
partitionsOut = fused :: (partitionsOut diff List(a))
664
665
665
- val preNeg = WtableNeg collect { case p if (lhs contains p._2) => p._1 }
666
- val postNeg = WtableNeg collect { case p if (lhs contains p._1) => p._2 }
666
+ val syms = lhs.map(_.sym).toSet
667
+ val preNeg = WtableNeg collect { case p if (syms contains p._2) => p._1 }
668
+ val postNeg = WtableNeg collect { case p if (syms contains p._1) => p._2 }
667
669
668
670
val fusedNeg = preNeg flatMap { s1 => postNeg map { s2 => (s1,s2) } }
669
671
WtableNeg = (fusedNeg ++ WtableNeg).distinct
@@ -722,19 +724,19 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
722
724
723
725
// prune Wloops (some might be no longer necessary)
724
726
Wloops = Wloops map {
725
- case TTP(lhs, mhs, SimpleFatLoop(s, x, rhs)) =>
726
- val ex = lhs map (s => currentScope exists (_.lhs == List(s)))
727
+ case TTP(lhs, SimpleFatLoop(s, x, rhs)) =>
728
+ val ex = lhs map (s => currentScope exists (_.lhs == List(s.sym )))
727
729
def select[A](a: List[A], b: List[Boolean]) = (a zip b) collect { case (w, true) => w }
728
- TTP(select(lhs, ex), select(mhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
730
+ TTP(select(lhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
729
731
}
730
732
731
733
// PREVIOUS PROBLEM: don't throw out all loops, might have some that are *not* in levelScope
732
734
// note: if we don't do it here, we will likely see a problem going back to innerScope in
733
735
// FatCodegen.focusExactScopeFat below. --> how to go back from SimpleFatLoop to VectorPlus??
734
736
// UPDATE: UloopSyms puts a tentative fix in place. check if it is sufficient!!
735
737
// what is the reason we cannot just look at Wloops??
736
- currentScope = currentScope.filter { case e@TTP(lhs, _, _ : AbstractFatLoop) =>
737
- val keep = UloopSyms contains lhs
738
+ currentScope = currentScope.filter { case e@TTP(lhs, _: AbstractFatLoop) =>
739
+ val keep = UloopSyms contains lhs.map(_.sym)
738
740
//if (!keep) println("dropping: " + e + ", not in UloopSyms: " + UloopSyms)
739
741
keep case _ => true } ::: Wloops
740
742
0 commit comments