Skip to content

Commit 41419a8

Browse files
authored
Merge pull request #104 from scalan/develop-0.9.x
Fix incorrect equality for ArrayNew and friends, add missing rewrites in PrimitiveOps
2 parents 97164c1 + afc2376 commit 41419a8

File tree

10 files changed

+151
-133
lines changed

10 files changed

+151
-133
lines changed

.gitignore

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
lib_managed
22
*.iml
3+
.idea/
34
.DS_Store
45
local.properties
56
project/boot
67
project/build/target
7-
target
8-
virtualization-lms-core.iml
9-
.gitingore
10-
#test-out
8+
target/
9+
test-out/
10+
!test-out/**/*.check
11+
data/

src/internal/Expressions.scala

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ trait Expressions extends Utils {
2020
def pos: List[SourceContext] = Nil
2121
}
2222

23-
case class Const[+T:Manifest](x: T) extends Exp[T]
23+
case class Const[+T:Manifest](x: T) extends Exp[T] {
24+
override def equals(other: Any) = other match {
25+
case c: Const[_] => x == c.x && tp == c.tp
26+
case _ => false
27+
}
28+
}
2429

2530
case class Sym[+T:Manifest](val id: Int) extends Exp[T] {
2631
var sourceContexts: List[SourceContext] = Nil
@@ -65,11 +70,11 @@ trait Expressions extends Utils {
6570
case _ => None
6671
}
6772

68-
def infix_defines[A](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
69-
case TP(sym: Sym[A], `rhs`) => Some(sym)
73+
def infix_defines[A: Manifest](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
74+
case TP(sym: Sym[A], `rhs`) if sym.tp <:< manifest[A] => Some(sym)
7075
case _ => None
7176
}
72-
77+
7378
case class TP[+T](sym: Sym[T], rhs: Def[T]) extends Stm
7479

7580
// graph construction state
@@ -108,16 +113,19 @@ trait Expressions extends Utils {
108113
globalDefsCache.get(s)
109114
//globalDefs.find(x => x.defines(s).nonEmpty)
110115

111-
def findDefinition[T](d: Def[T]): Option[Stm] =
116+
def findDefinition[T: Manifest](d: Def[T]): Option[Stm] =
112117
globalDefs.find(x => x.defines(d).nonEmpty)
113118

114119
def findOrCreateDefinition[T:Manifest](d: Def[T], pos: List[SourceContext]): Stm =
115120
findDefinition[T](d) map { x => x.defines(d).foreach(_.withPos(pos)); x } getOrElse {
116121
createDefinition(fresh[T](pos), d)
117122
}
118123

119-
def findOrCreateDefinitionExp[T:Manifest](d: Def[T], pos: List[SourceContext]): Exp[T] =
120-
findOrCreateDefinition(d, pos).defines(d).get
124+
def findOrCreateDefinitionExp[T:Manifest](d: Def[T], pos: List[SourceContext]): Exp[T] = {
125+
val stm = findOrCreateDefinition(d, pos)
126+
val optExp = stm.defines(d)
127+
optExp.get
128+
}
121129

122130
def createDefinition[T](s: Sym[T], d: Def[T]): Stm = {
123131
val f = TP(s, d)

src/internal/FatExpressions.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,16 @@ trait FatExpressions extends Expressions {
2424
case _ => super.infix_defines(stm, sym)
2525
}
2626

27-
override def infix_defines[A](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
28-
case TTP(lhs, mhs, rhs) => mhs.indexOf(rhs) match { case idx if idx >= 0 => Some(lhs(idx).asInstanceOf[Sym[A]]) case _ => None }
27+
override def infix_defines[A: Manifest](stm: Stm, rhs: Def[A]): Option[Sym[A]] = stm match {
28+
case TTP(lhs, mhs, rhs) => mhs.indexOf(rhs) match {
29+
case idx if idx >= 0 =>
30+
val sym = lhs(idx)
31+
if (sym.tp <:< manifest[A])
32+
Some(sym.asInstanceOf[Sym[A]])
33+
else
34+
None
35+
case _ => None
36+
}
2937
case _ => super.infix_defines(stm, rhs)
3038
}
3139

test-out/epfl/test14-queries2.check

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,39 @@ nVars=2000
66
import scala.lms.epfl.test14.Schema
77
class staged$0 extends ((Unit)=>(Unit)) {
88
def apply(x0:Unit): Unit = {
9-
val x2550 = println("rangeFromNames(\"Edna\",\"Bert\"):")
10-
val x617 = Schema.db.people.flatMap { x111 =>
11-
val x616 = Schema.db.people.flatMap { x435 =>
12-
val x615 = Schema.db.people.flatMap { x598 =>
13-
val x112 = x111.name
14-
val x113 = x112 == "Edna"
15-
val x438 = x435.name
16-
val x439 = x438 == "Bert"
17-
val x116 = x111.age
18-
val x602 = x598.age
19-
val x603 = x116 <= x602
20-
val x441 = x435.age
21-
val x604 = x602 < x441
22-
val x605 = x603 && x604
23-
val x611 = x439 && x605
24-
val x613 = x113 && x611
25-
val x614 = if (x613) {
26-
val x607 = x598.name
27-
val x608 = new Schema.Record { val name = x607 }
28-
val x609 = List(x608)
29-
x609
9+
val x2580 = println("rangeFromNames(\"Edna\",\"Bert\"):")
10+
val x641 = Schema.db.people.flatMap { x112 =>
11+
val x640 = Schema.db.people.flatMap { x445 =>
12+
val x639 = Schema.db.people.flatMap { x620 =>
13+
val x113 = x112.name
14+
val x114 = x113 == "Edna"
15+
val x448 = x445.name
16+
val x449 = x448 == "Bert"
17+
val x117 = x112.age
18+
val x624 = x620.age
19+
val x625 = x117 <= x624
20+
val x452 = x445.age
21+
val x626 = x624 < x452
22+
val x627 = x625 && x626
23+
val x634 = x449 && x627
24+
val x637 = x114 && x634
25+
val x638 = if (x637) {
26+
val x630 = x620.name
27+
val x631 = new Schema.Record { val name = x630 }
28+
val x632 = List(x631)
29+
x632
3030
} else {
31-
val x19 = List()
32-
x19
31+
val x34 = List()
32+
x34
3333
}
34-
x614
34+
x638
3535
}
36-
x615
36+
x639
3737
}
38-
x616
38+
x640
3939
}
40-
val x2551 = println(x617)
41-
x2551
40+
val x2581 = println(x641)
41+
x2581
4242
}
4343
}
4444
/*****************************************

test-out/epfl/test14-queries3.check

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,55 @@ nVars=2000
66
import scala.lms.epfl.test14.Schema
77
class staged$0 extends ((Unit)=>(Unit)) {
88
def apply(x0:Unit): Unit = {
9-
val x2550 = println("expertise(\"abstract\"):")
10-
val x800 = Schema.org.departments.flatMap { x732 =>
11-
val x792 = Schema.org.employees.flatMap { x763 =>
12-
val x734 = x732.dpt
13-
val x764 = x763.dpt
14-
val x765 = x734 == x764
15-
val x785 = Schema.org.tasks.flatMap { x776 =>
16-
val x767 = x763.emp
17-
val x777 = x776.emp
18-
val x778 = x767 == x777
19-
val x779 = x776.tsk
20-
val x780 = x779 == "abstract"
21-
val x781 = x778 && x780
22-
val x784 = if (x781) {
23-
val x661 = new Schema.Record { val ignore = () }
24-
val x688 = List(x661)
25-
x688
9+
val x2580 = println("expertise(\"abstract\"):")
10+
val x824 = Schema.org.departments.flatMap { x756 =>
11+
val x816 = Schema.org.employees.flatMap { x787 =>
12+
val x758 = x756.dpt
13+
val x788 = x787.dpt
14+
val x789 = x758 == x788
15+
val x809 = Schema.org.tasks.flatMap { x800 =>
16+
val x791 = x787.emp
17+
val x801 = x800.emp
18+
val x802 = x791 == x801
19+
val x803 = x800.tsk
20+
val x804 = x803 == "abstract"
21+
val x805 = x802 && x804
22+
val x808 = if (x805) {
23+
val x685 = new Schema.Record { val ignore = () }
24+
val x712 = List(x685)
25+
x712
2626
} else {
27-
val x19 = List()
28-
x19
27+
val x34 = List()
28+
x34
2929
}
30-
x784
30+
x808
3131
}
32-
val x786 = x785.isEmpty
33-
val x788 = x765 && x786
34-
val x791 = if (x788) {
35-
val x661 = new Schema.Record { val ignore = () }
36-
val x688 = List(x661)
37-
x688
32+
val x810 = x809.isEmpty
33+
val x812 = x789 && x810
34+
val x815 = if (x812) {
35+
val x685 = new Schema.Record { val ignore = () }
36+
val x712 = List(x685)
37+
x712
3838
} else {
39-
val x19 = List()
40-
x19
39+
val x34 = List()
40+
x34
4141
}
42-
x791
42+
x815
4343
}
44-
val x793 = x792.isEmpty
45-
val x799 = if (x793) {
46-
val x734 = x732.dpt
47-
val x797 = new Schema.Record { val dpt = x734 }
48-
val x798 = List(x797)
49-
x798
44+
val x817 = x816.isEmpty
45+
val x823 = if (x817) {
46+
val x758 = x756.dpt
47+
val x821 = new Schema.Record { val dpt = x758 }
48+
val x822 = List(x821)
49+
x822
5050
} else {
51-
val x19 = List()
52-
x19
51+
val x34 = List()
52+
x34
5353
}
54-
x799
54+
x823
5555
}
56-
val x2551 = println(x800)
57-
x2551
56+
val x2581 = println(x824)
57+
x2581
5858
}
5959
}
6060
/*****************************************

test-out/epfl/test14-queries4.check

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,55 @@ nVars=2000
66
import scala.lms.epfl.test14.Schema
77
class staged$0 extends ((Unit)=>(Unit)) {
88
def apply(x0:Unit): Unit = {
9-
val x2550 = println("expertise2(\"abstract\"):")
10-
val x1147 = Schema.org.departments.flatMap { x991 =>
11-
val x1140 = Schema.org.employees.flatMap { x1083 =>
12-
val x993 = x991.dpt
13-
val x1084 = x1083.dpt
14-
val x1085 = x993 == x1084
15-
val x1133 = Schema.org.tasks.flatMap { x1120 =>
16-
val x1088 = x1083.emp
17-
val x1121 = x1120.emp
18-
val x1122 = x1088 == x1121
19-
val x1125 = x1120.tsk
20-
val x1128 = x1125 == "abstract"
21-
val x1130 = x1122 && x1128
22-
val x1132 = if (x1130) {
23-
val x661 = new Schema.Record { val ignore = () }
24-
val x688 = List(x661)
25-
x688
9+
val x2580 = println("expertise2(\"abstract\"):")
10+
val x1177 = Schema.org.departments.flatMap { x1018 =>
11+
val x1169 = Schema.org.employees.flatMap { x1111 =>
12+
val x1020 = x1018.dpt
13+
val x1112 = x1111.dpt
14+
val x1113 = x1020 == x1112
15+
val x1161 = Schema.org.tasks.flatMap { x1148 =>
16+
val x1116 = x1111.emp
17+
val x1149 = x1148.emp
18+
val x1150 = x1116 == x1149
19+
val x1153 = x1148.tsk
20+
val x1156 = x1153 == "abstract"
21+
val x1158 = x1150 && x1156
22+
val x1160 = if (x1158) {
23+
val x685 = new Schema.Record { val ignore = () }
24+
val x712 = List(x685)
25+
x712
2626
} else {
27-
val x19 = List()
28-
x19
27+
val x34 = List()
28+
x34
2929
}
30-
x1132
30+
x1160
3131
}
32-
val x1134 = x1133.isEmpty
33-
val x1137 = x1085 && x1134
34-
val x1139 = if (x1137) {
35-
val x661 = new Schema.Record { val ignore = () }
36-
val x688 = List(x661)
37-
x688
32+
val x1162 = x1161.isEmpty
33+
val x1166 = x1113 && x1162
34+
val x1168 = if (x1166) {
35+
val x685 = new Schema.Record { val ignore = () }
36+
val x712 = List(x685)
37+
x712
3838
} else {
39-
val x19 = List()
40-
x19
39+
val x34 = List()
40+
x34
4141
}
42-
x1139
42+
x1168
4343
}
44-
val x1141 = x1140.isEmpty
45-
val x1146 = if (x1141) {
46-
val x993 = x991.dpt
47-
val x1144 = new Schema.Record { val dpt = x993 }
48-
val x1145 = List(x1144)
49-
x1145
44+
val x1170 = x1169.isEmpty
45+
val x1176 = if (x1170) {
46+
val x1020 = x1018.dpt
47+
val x1174 = new Schema.Record { val dpt = x1020 }
48+
val x1175 = List(x1174)
49+
x1175
5050
} else {
51-
val x19 = List()
52-
x19
51+
val x34 = List()
52+
x34
5353
}
54-
x1146
54+
x1176
5555
}
56-
val x2551 = println(x1147)
57-
x2551
56+
val x2581 = println(x1177)
57+
x2581
5858
}
5959
}
6060
/*****************************************

test-out/epfl/test9-struct3.check

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,30 @@ partitions: List(TTP(List(Sym(4), Sym(5), Sym(13)),List(SimpleLoop(Const(100),Sy
33
considering TP(Sym(9),ArrayIndex(Sym(5),Sym(7)))
44
replace TP(Sym(9),ArrayIndex(Sym(5),Sym(7))) at 1 within TTP(List(Sym(4), Sym(5), Sym(13)),List(SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1)))), SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(2)))), SimpleLoop(Const(100),Sym(7),ArrayElem(Block(Sym(11))))),SimpleFatLoop(Const(100),Sym(1),List(ArrayElem(Block(Sym(1))), ArrayElem(Block(Sym(2))), ArrayElem(Block(Sym(11))))))
55
warning: mirroring of Sym(11)=Minus(Const(0.0),Sym(9)) type Double returned Sym(1) type Int (not a subtype)
6-
warning: mirroring of Sym(13)=SimpleLoop(Const(100),Sym(7),ArrayElem(Block(Sym(11)))) type Array[Double] returned Sym(4)=SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1)))) type Array[Int] (not a subtype)
76
try once more ...
87
wtableneg: List()
9-
partitions: List(TTP(List(Sym(4)),List(SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1))))),SimpleFatLoop(Const(100),Sym(1),List(ArrayElem(Block(Sym(1)))))))
8+
partitions: List(TTP(List(Sym(4), Sym(17)),List(SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1)))), SimpleLoop(Const(100),Sym(1),ArrayElem(Block(Sym(1))))),SimpleFatLoop(Const(100),Sym(1),List(ArrayElem(Block(Sym(1))), ArrayElem(Block(Sym(1)))))))
109
no changes, we're done
11-
super.focusExactScopeFat with result changed from List(Sym(16)) to List(Sym(19))
10+
super.focusExactScopeFat with result changed from List(Sym(16)) to List(Sym(20))
1211
/*****************************************
1312
Emitting Generated Code
1413
*******************************************/
1514
class Test extends ((Int)=>(Unit)) {
1615
def apply(x0:Int): Unit = {
1716
var x4 = new Array[Int](100)
17+
var x17 = new Array[Int](100)
1818
for (x1 <- 0 until 100) {
1919
x4(x1) = x1
20+
x17(x1) = x1
2021
}
21-
val x17 = new ArrayOfAnon189207751(x4,x4)
22-
val x18 = println(x17)
23-
val x15 = x18
22+
val x18 = new ArrayOfAnon189207751(x4,x17)
23+
val x19 = println(x18)
24+
val x15 = x19
2425
x15
2526
}
2627
}
2728
/*****************************************
2829
End of Generated Code
2930
*******************************************/
3031

31-
case class ArrayOfAnon189207751(re: Array[Int], im: Array[Int])
32+
case class ArrayOfAnon189207751(re: Array[Int], im: Array[Double])

test-src/epfl/test2-fft/DisableOpts.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package test2
55
import internal._
66

77
trait DisableCSE extends Expressions {
8-
override def findDefinition[T](d: Def[T]) = None
8+
override def findDefinition[T: Manifest](d: Def[T]) = None
99
}
1010

1111

0 commit comments

Comments
 (0)