Skip to content

Commit 5f11cfb

Browse files
authored
Merge pull request #115 from astojanov/develop
NumericOps: adding Const() optimisations.
2 parents b94ac17 + ea97605 commit 5f11cfb

File tree

1 file changed

+29
-35
lines changed

1 file changed

+29
-35
lines changed

src/common/NumericOps.scala

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,22 @@ trait NumericOps extends Variables {
1414
this: PrimitiveOps =>
1515

1616
// workaround for infix not working with manifests
17-
implicit def numericToNumericOps[T:Numeric:Typ](n: T) = new NumericOpsCls(unit(n))
17+
implicit def numericToNumericOps [T:Numeric:Typ](n: T) = new NumericOpsCls(unit(n))
1818
implicit def repNumericToNumericOps[T:Numeric:Typ](n: Rep[T]) = new NumericOpsCls(n)
1919
implicit def varNumericToNumericOps[T:Numeric:Typ](n: Var[T]) = new NumericOpsCls(readVar(n))
2020

2121
class NumericOpsCls[T:Numeric:Typ](lhs: Rep[T]){
2222
def +[A](rhs: A)(implicit c: A => T, pos: SourceContext) = numeric_plus(lhs,unit(c(rhs)))
23-
def +(rhs: Rep[T])(implicit pos: SourceContext) = numeric_plus(lhs,rhs)
24-
def -(rhs: Rep[T])(implicit pos: SourceContext) = numeric_minus(lhs,rhs)
25-
def *(rhs: Rep[T])(implicit pos: SourceContext) = numeric_times(lhs,rhs)
23+
def +(rhs: Rep[T])(implicit pos: SourceContext) = numeric_plus (lhs,rhs)
24+
def -(rhs: Rep[T])(implicit pos: SourceContext) = numeric_minus (lhs,rhs)
25+
def *(rhs: Rep[T])(implicit pos: SourceContext) = numeric_times (lhs,rhs)
2626
def /(rhs: Rep[T])(implicit pos: SourceContext) = numeric_divide(lhs,rhs)
2727
}
2828

29-
def numeric_plus[T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
30-
def numeric_minus[T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
31-
def numeric_times[T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
29+
def numeric_plus [T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
30+
def numeric_minus [T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
31+
def numeric_times [T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
3232
def numeric_divide[T:Numeric:Typ](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
33-
//def numeric_negate[T:Numeric](x: T): Rep[T]
34-
//def numeric_abs[T:Numeric](x: T): Rep[T]
35-
//def numeric_signum[T:Numeric](x: T): Rep[Int]
3633
}
3734

3835
trait NumericOpsExp extends NumericOps with VariablesExp with BaseFatExp {
@@ -43,27 +40,25 @@ trait NumericOpsExp extends NumericOps with VariablesExp with BaseFatExp {
4340
def aev = implicitly[Numeric[A]]
4441
}
4542

46-
case class NumericPlus[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
47-
case class NumericMinus[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
48-
case class NumericTimes[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
49-
case class NumericDivide[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
43+
case class NumericPlus [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
44+
case class NumericMinus [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
45+
case class NumericTimes [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
46+
case class NumericDivide [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T]) extends DefMN[T]
5047

51-
def numeric_plus[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericPlus(lhs, rhs)
52-
def numeric_minus[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericMinus(lhs, rhs)
53-
def numeric_times[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericTimes(lhs, rhs)
54-
def numeric_divide[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericDivide(lhs, rhs)
48+
def numeric_plus [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericPlus(lhs, rhs)
49+
def numeric_minus [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericMinus(lhs, rhs)
50+
def numeric_times [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericTimes(lhs, rhs)
51+
def numeric_divide [T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext) : Exp[T] = NumericDivide(lhs, rhs)
5552

5653
override def mirror[A:Typ](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match {
57-
case e@NumericPlus(l,r) => numeric_plus(f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
58-
case e@NumericMinus(l,r) => numeric_minus(f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
59-
case e@NumericTimes(l,r) => numeric_times(f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
54+
case e@NumericPlus (l,r) => numeric_plus (f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
55+
case e@NumericMinus (l,r) => numeric_minus (f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
56+
case e@NumericTimes (l,r) => numeric_times (f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
6057
case e@NumericDivide(l,r) => numeric_divide(f(l), f(r))(e.aev.asInstanceOf[Numeric[A]], mtype(e.mev), pos)
6158
case _ => super.mirror(e,f)
6259
}).asInstanceOf[Exp[A]]
63-
6460
}
6561

66-
6762
trait NumericOpsExpOpt extends NumericOpsExp {
6863
this: PrimitiveOpsExp =>
6964

@@ -75,6 +70,7 @@ trait NumericOpsExpOpt extends NumericOpsExp {
7570
}
7671
override def numeric_minus[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = (lhs,rhs) match {
7772
case (Const(x), Const(y)) => Const(implicitly[Numeric[T]].minus(x,y))
73+
case (x, Const(y)) if y == implicitly[Numeric[T]].zero => x
7874
case _ => super.numeric_minus(lhs,rhs)
7975
}
8076
override def numeric_times[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = (lhs,rhs) match {
@@ -88,6 +84,8 @@ trait NumericOpsExpOpt extends NumericOpsExp {
8884
override def numeric_divide[T:Numeric:Typ](lhs: Exp[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = (lhs,rhs) match {
8985
// CAVEAT: Numeric doesn't have .div, Fractional has
9086
case (Const(x), Const(y)) => Const(implicitly[Numeric[T]].asInstanceOf[Fractional[T]].div(x,y))
87+
case (Const(x), y) if x == implicitly[Numeric[T]].zero => Const(x)
88+
case (x, Const(y)) if y == implicitly[Numeric[T]].one => x
9189
case _ => super.numeric_divide(lhs,rhs)
9290
}
9391
}
@@ -98,10 +96,10 @@ trait ScalaGenNumericOps extends ScalaGenFat {
9896
import IR._
9997

10098
override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
101-
case NumericPlus(a,b) => emitValDef(sym, src"$a + $b")
102-
case NumericMinus(a,b) => emitValDef(sym, src"$a - $b")
103-
case NumericTimes(a,b) => emitValDef(sym, src"$a * $b")
104-
case NumericDivide(a,b) => emitValDef(sym, src"$a / $b")
99+
case NumericPlus (a,b) => emitValDef(sym, src"$a + $b")
100+
case NumericMinus (a,b) => emitValDef(sym, src"$a - $b")
101+
case NumericTimes (a,b) => emitValDef(sym, src"$a * $b")
102+
case NumericDivide (a,b) => emitValDef(sym, src"$a / $b")
105103
case _ => super.emitNode(sym, rhs)
106104
}
107105
}
@@ -112,14 +110,10 @@ trait CLikeGenNumericOps extends CLikeGenBase {
112110

113111
override def emitNode(sym: Sym[Any], rhs: Def[Any]) = {
114112
rhs match {
115-
case NumericPlus(a,b) =>
116-
emitValDef(sym, src"$a + $b")
117-
case NumericMinus(a,b) =>
118-
emitValDef(sym, src"$a - $b")
119-
case NumericTimes(a,b) =>
120-
emitValDef(sym, src"$a * $b")
121-
case NumericDivide(a,b) =>
122-
emitValDef(sym, src"$a / $b")
113+
case NumericPlus (a,b) => emitValDef(sym, src"$a + $b")
114+
case NumericMinus (a,b) => emitValDef(sym, src"$a - $b")
115+
case NumericTimes (a,b) => emitValDef(sym, src"$a * $b")
116+
case NumericDivide(a,b) => emitValDef(sym, src"$a / $b")
123117
case _ => super.emitNode(sym, rhs)
124118
}
125119
}

0 commit comments

Comments
 (0)