Skip to content

Fix incorrect equality for ArrayNew and friends, add missing rewrites in PrimitiveOps #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 1, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
lib_managed
*.iml
.idea/
.DS_Store
local.properties
project/boot
project/build/target
target
virtualization-lms-core.iml
.gitingore
#test-out
target/
test-out/
!test-out/**/*.check
data/
22 changes: 15 additions & 7 deletions src/internal/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ trait Expressions extends Utils {
def pos: List[SourceContext] = Nil
}

case class Const[+T:Manifest](x: T) extends Exp[T]
case class Const[+T:Manifest](x: T) extends Exp[T] {
override def equals(other: Any) = other match {
case c: Const[_] => x == c.x && tp == c.tp
case _ => false
}
}

case class Sym[+T:Manifest](val id: Int) extends Exp[T] {
var sourceContexts: List[SourceContext] = Nil
Expand Down Expand Up @@ -65,11 +70,11 @@ trait Expressions extends Utils {
case _ => None
}

def infix_defines[A](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
case TP(sym: Sym[A], `rhs`) => Some(sym)
def infix_defines[A: Manifest](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
case TP(sym: Sym[A], `rhs`) if sym.tp <:< manifest[A] => Some(sym)
case _ => None
}

case class TP[+T](sym: Sym[T], rhs: Def[T]) extends Stm

// graph construction state
Expand Down Expand Up @@ -108,16 +113,19 @@ trait Expressions extends Utils {
globalDefsCache.get(s)
//globalDefs.find(x => x.defines(s).nonEmpty)

def findDefinition[T](d: Def[T]): Option[Stm] =
def findDefinition[T: Manifest](d: Def[T]): Option[Stm] =
globalDefs.find(x => x.defines(d).nonEmpty)

def findOrCreateDefinition[T:Manifest](d: Def[T], pos: List[SourceContext]): Stm =
findDefinition[T](d) map { x => x.defines(d).foreach(_.withPos(pos)); x } getOrElse {
createDefinition(fresh[T](pos), d)
}

def findOrCreateDefinitionExp[T:Manifest](d: Def[T], pos: List[SourceContext]): Exp[T] =
findOrCreateDefinition(d, pos).defines(d).get
def findOrCreateDefinitionExp[T:Manifest](d: Def[T], pos: List[SourceContext]): Exp[T] = {
val stm = findOrCreateDefinition(d, pos)
val optExp = stm.defines(d)
optExp.get
}

def createDefinition[T](s: Sym[T], d: Def[T]): Stm = {
val f = TP(s, d)
Expand Down
12 changes: 10 additions & 2 deletions src/internal/FatExpressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ trait FatExpressions extends Expressions {
case _ => super.infix_defines(stm, sym)
}

override def infix_defines[A](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
case TTP(lhs, mhs, rhs) => mhs.indexOf(rhs) match { case idx if idx >= 0 => Some(lhs(idx).asInstanceOf[Sym[A]]) case _ => None }
override def infix_defines[A: Manifest](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
case TTP(lhs, mhs, rhs) => mhs.indexOf(rhs) match {
case idx if idx >= 0 =>
val sym = lhs(idx)
if (sym.tp <:< manifest[A])
Some(sym.asInstanceOf[Sym[A]])
else
None
case _ => None
}
case _ => super.infix_defines(stm, rhs)
}

Expand Down
56 changes: 28 additions & 28 deletions test-out/epfl/test14-queries2.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,39 @@ nVars=2000
import scala.lms.epfl.test14.Schema
class staged$0 extends ((Unit)=>(Unit)) {
def apply(x0:Unit): Unit = {
val x2550 = println("rangeFromNames(\"Edna\",\"Bert\"):")
val x617 = Schema.db.people.flatMap { x111 =>
val x616 = Schema.db.people.flatMap { x435 =>
val x615 = Schema.db.people.flatMap { x598 =>
val x112 = x111.name
val x113 = x112 == "Edna"
val x438 = x435.name
val x439 = x438 == "Bert"
val x116 = x111.age
val x602 = x598.age
val x603 = x116 <= x602
val x441 = x435.age
val x604 = x602 < x441
val x605 = x603 && x604
val x611 = x439 && x605
val x613 = x113 && x611
val x614 = if (x613) {
val x607 = x598.name
val x608 = new Schema.Record { val name = x607 }
val x609 = List(x608)
x609
val x2580 = println("rangeFromNames(\"Edna\",\"Bert\"):")
val x641 = Schema.db.people.flatMap { x112 =>
val x640 = Schema.db.people.flatMap { x445 =>
val x639 = Schema.db.people.flatMap { x620 =>
val x113 = x112.name
val x114 = x113 == "Edna"
val x448 = x445.name
val x449 = x448 == "Bert"
val x117 = x112.age
val x624 = x620.age
val x625 = x117 <= x624
val x452 = x445.age
val x626 = x624 < x452
val x627 = x625 && x626
val x634 = x449 && x627
val x637 = x114 && x634
val x638 = if (x637) {
val x630 = x620.name
val x631 = new Schema.Record { val name = x630 }
val x632 = List(x631)
x632
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x614
x638
}
x615
x639
}
x616
x640
}
val x2551 = println(x617)
x2551
val x2581 = println(x641)
x2581
}
}
/*****************************************
Expand Down
80 changes: 40 additions & 40 deletions test-out/epfl/test14-queries3.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,55 @@ nVars=2000
import scala.lms.epfl.test14.Schema
class staged$0 extends ((Unit)=>(Unit)) {
def apply(x0:Unit): Unit = {
val x2550 = println("expertise(\"abstract\"):")
val x800 = Schema.org.departments.flatMap { x732 =>
val x792 = Schema.org.employees.flatMap { x763 =>
val x734 = x732.dpt
val x764 = x763.dpt
val x765 = x734 == x764
val x785 = Schema.org.tasks.flatMap { x776 =>
val x767 = x763.emp
val x777 = x776.emp
val x778 = x767 == x777
val x779 = x776.tsk
val x780 = x779 == "abstract"
val x781 = x778 && x780
val x784 = if (x781) {
val x661 = new Schema.Record { val ignore = () }
val x688 = List(x661)
x688
val x2580 = println("expertise(\"abstract\"):")
val x824 = Schema.org.departments.flatMap { x756 =>
val x816 = Schema.org.employees.flatMap { x787 =>
val x758 = x756.dpt
val x788 = x787.dpt
val x789 = x758 == x788
val x809 = Schema.org.tasks.flatMap { x800 =>
val x791 = x787.emp
val x801 = x800.emp
val x802 = x791 == x801
val x803 = x800.tsk
val x804 = x803 == "abstract"
val x805 = x802 && x804
val x808 = if (x805) {
val x685 = new Schema.Record { val ignore = () }
val x712 = List(x685)
x712
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x784
x808
}
val x786 = x785.isEmpty
val x788 = x765 && x786
val x791 = if (x788) {
val x661 = new Schema.Record { val ignore = () }
val x688 = List(x661)
x688
val x810 = x809.isEmpty
val x812 = x789 && x810
val x815 = if (x812) {
val x685 = new Schema.Record { val ignore = () }
val x712 = List(x685)
x712
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x791
x815
}
val x793 = x792.isEmpty
val x799 = if (x793) {
val x734 = x732.dpt
val x797 = new Schema.Record { val dpt = x734 }
val x798 = List(x797)
x798
val x817 = x816.isEmpty
val x823 = if (x817) {
val x758 = x756.dpt
val x821 = new Schema.Record { val dpt = x758 }
val x822 = List(x821)
x822
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x799
x823
}
val x2551 = println(x800)
x2551
val x2581 = println(x824)
x2581
}
}
/*****************************************
Expand Down
80 changes: 40 additions & 40 deletions test-out/epfl/test14-queries4.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,55 @@ nVars=2000
import scala.lms.epfl.test14.Schema
class staged$0 extends ((Unit)=>(Unit)) {
def apply(x0:Unit): Unit = {
val x2550 = println("expertise2(\"abstract\"):")
val x1147 = Schema.org.departments.flatMap { x991 =>
val x1140 = Schema.org.employees.flatMap { x1083 =>
val x993 = x991.dpt
val x1084 = x1083.dpt
val x1085 = x993 == x1084
val x1133 = Schema.org.tasks.flatMap { x1120 =>
val x1088 = x1083.emp
val x1121 = x1120.emp
val x1122 = x1088 == x1121
val x1125 = x1120.tsk
val x1128 = x1125 == "abstract"
val x1130 = x1122 && x1128
val x1132 = if (x1130) {
val x661 = new Schema.Record { val ignore = () }
val x688 = List(x661)
x688
val x2580 = println("expertise2(\"abstract\"):")
val x1177 = Schema.org.departments.flatMap { x1018 =>
val x1169 = Schema.org.employees.flatMap { x1111 =>
val x1020 = x1018.dpt
val x1112 = x1111.dpt
val x1113 = x1020 == x1112
val x1161 = Schema.org.tasks.flatMap { x1148 =>
val x1116 = x1111.emp
val x1149 = x1148.emp
val x1150 = x1116 == x1149
val x1153 = x1148.tsk
val x1156 = x1153 == "abstract"
val x1158 = x1150 && x1156
val x1160 = if (x1158) {
val x685 = new Schema.Record { val ignore = () }
val x712 = List(x685)
x712
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x1132
x1160
}
val x1134 = x1133.isEmpty
val x1137 = x1085 && x1134
val x1139 = if (x1137) {
val x661 = new Schema.Record { val ignore = () }
val x688 = List(x661)
x688
val x1162 = x1161.isEmpty
val x1166 = x1113 && x1162
val x1168 = if (x1166) {
val x685 = new Schema.Record { val ignore = () }
val x712 = List(x685)
x712
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x1139
x1168
}
val x1141 = x1140.isEmpty
val x1146 = if (x1141) {
val x993 = x991.dpt
val x1144 = new Schema.Record { val dpt = x993 }
val x1145 = List(x1144)
x1145
val x1170 = x1169.isEmpty
val x1176 = if (x1170) {
val x1020 = x1018.dpt
val x1174 = new Schema.Record { val dpt = x1020 }
val x1175 = List(x1174)
x1175
} else {
val x19 = List()
x19
val x34 = List()
x34
}
x1146
x1176
}
val x2551 = println(x1147)
x2551
val x2581 = println(x1177)
x2581
}
}
/*****************************************
Expand Down
15 changes: 8 additions & 7 deletions test-out/epfl/test9-struct3.check
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,30 @@ partitions: List(TTP(List(Sym(4), Sym(5), Sym(13)),List(SimpleLoop(Const(100),Sy
considering TP(Sym(9),ArrayIndex(Sym(5),Sym(7)))
replace TP(Sym(9),ArrayIndex(Sym(5),Sym(7))) at 1 within TTP(List(Sym(4), Sym(5), Sym(13)),List(SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1)))), SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(2)))), SimpleLoop(Const(100),Sym(7),ArrayElem(Block(Sym(11))))),SimpleFatLoop(Const(100),Sym(1),List(ArrayElem(Block(Sym(1))), ArrayElem(Block(Sym(2))), ArrayElem(Block(Sym(11))))))
warning: mirroring of Sym(11)=Minus(Const(0.0),Sym(9)) type Double returned Sym(1) type Int (not a subtype)
warning: mirroring of Sym(13)=SimpleLoop(Const(100),Sym(7),ArrayElem(Block(Sym(11)))) type Array[Double] returned Sym(4)=SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1)))) type Array[Int] (not a subtype)
try once more ...
wtableneg: List()
partitions: List(TTP(List(Sym(4)),List(SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1))))),SimpleFatLoop(Const(100),Sym(1),List(ArrayElem(Block(Sym(1)))))))
partitions: List(TTP(List(Sym(4), Sym(17)),List(SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1)))), SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1))))),SimpleFatLoop(Const(100),Sym(1),List(ArrayElem(Block(Sym(1))), ArrayElem(Block(Sym(1)))))))
no changes, we're done
super.focusExactScopeFat with result changed from List(Sym(16)) to List(Sym(19))
super.focusExactScopeFat with result changed from List(Sym(16)) to List(Sym(20))
/*****************************************
Emitting Generated Code
*******************************************/
class Test extends ((Int)=>(Unit)) {
def apply(x0:Int): Unit = {
var x4 = new Array[Int](100)
var x17 = new Array[Int](100)
for (x1 <- 0 until 100) {
x4(x1) = x1
x17(x1) = x1
}
val x17 = new ArrayOfAnon189207751(x4,x4)
val x18 = println(x17)
val x15 = x18
val x18 = new ArrayOfAnon189207751(x4,x17)
val x19 = println(x18)
val x15 = x19
x15
}
}
/*****************************************
End of Generated Code
*******************************************/

case class ArrayOfAnon189207751(re: Array[Int], im: Array[Int])
case class ArrayOfAnon189207751(re: Array[Int], im: Array[Double])
2 changes: 1 addition & 1 deletion test-src/epfl/test2-fft/DisableOpts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package test2
import internal._

trait DisableCSE extends Expressions {
override def findDefinition[T](d: Def[T]) = None
override def findDefinition[T: Manifest](d: Def[T]) = None
}


Expand Down
Loading