@@ -14,25 +14,22 @@ trait NumericOps extends Variables {
14
14
this : PrimitiveOps =>
15
15
16
16
// workaround for infix not working with manifests
17
- implicit def numericToNumericOps [T : Numeric : Typ ](n : T ) = new NumericOpsCls (unit(n))
17
+ implicit def numericToNumericOps [T : Numeric : Typ ](n : T ) = new NumericOpsCls (unit(n))
18
18
implicit def repNumericToNumericOps [T : Numeric : Typ ](n : Rep [T ]) = new NumericOpsCls (n)
19
19
implicit def varNumericToNumericOps [T : Numeric : Typ ](n : Var [T ]) = new NumericOpsCls (readVar(n))
20
20
21
21
class NumericOpsCls [T : Numeric : Typ ](lhs : Rep [T ]){
22
22
def + [A ](rhs : A )(implicit c : A => T , pos : SourceContext ) = numeric_plus(lhs,unit(c(rhs)))
23
- def + (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_plus(lhs,rhs)
24
- def - (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_minus(lhs,rhs)
25
- def * (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_times(lhs,rhs)
23
+ def + (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_plus (lhs,rhs)
24
+ def - (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_minus (lhs,rhs)
25
+ def * (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_times (lhs,rhs)
26
26
def / (rhs : Rep [T ])(implicit pos : SourceContext ) = numeric_divide(lhs,rhs)
27
27
}
28
28
29
- def numeric_plus [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
30
- def numeric_minus [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
31
- def numeric_times [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
29
+ def numeric_plus [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
30
+ def numeric_minus [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
31
+ def numeric_times [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
32
32
def numeric_divide [T : Numeric : Typ ](lhs : Rep [T ], rhs : Rep [T ])(implicit pos : SourceContext ): Rep [T ]
33
- // def numeric_negate[T:Numeric](x: T): Rep[T]
34
- // def numeric_abs[T:Numeric](x: T): Rep[T]
35
- // def numeric_signum[T:Numeric](x: T): Rep[Int]
36
33
}
37
34
38
35
trait NumericOpsExp extends NumericOps with VariablesExp with BaseFatExp {
@@ -43,27 +40,25 @@ trait NumericOpsExp extends NumericOps with VariablesExp with BaseFatExp {
43
40
def aev = implicitly[Numeric [A ]]
44
41
}
45
42
46
- case class NumericPlus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
47
- case class NumericMinus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
48
- case class NumericTimes [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
49
- case class NumericDivide [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
43
+ case class NumericPlus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
44
+ case class NumericMinus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
45
+ case class NumericTimes [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
46
+ case class NumericDivide [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ]) extends DefMN [T ]
50
47
51
- def numeric_plus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericPlus (lhs, rhs)
52
- def numeric_minus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericMinus (lhs, rhs)
53
- def numeric_times [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericTimes (lhs, rhs)
54
- def numeric_divide [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericDivide (lhs, rhs)
48
+ def numeric_plus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericPlus (lhs, rhs)
49
+ def numeric_minus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericMinus (lhs, rhs)
50
+ def numeric_times [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericTimes (lhs, rhs)
51
+ def numeric_divide [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ) : Exp [T ] = NumericDivide (lhs, rhs)
55
52
56
53
override def mirror [A : Typ ](e : Def [A ], f : Transformer )(implicit pos : SourceContext ): Exp [A ] = (e match {
57
- case e@ NumericPlus (l,r) => numeric_plus(f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
58
- case e@ NumericMinus (l,r) => numeric_minus(f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
59
- case e@ NumericTimes (l,r) => numeric_times(f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
54
+ case e@ NumericPlus (l,r) => numeric_plus (f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
55
+ case e@ NumericMinus (l,r) => numeric_minus (f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
56
+ case e@ NumericTimes (l,r) => numeric_times (f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
60
57
case e@ NumericDivide (l,r) => numeric_divide(f(l), f(r))(e.aev.asInstanceOf [Numeric [A ]], mtype(e.mev), pos)
61
58
case _ => super .mirror(e,f)
62
59
}).asInstanceOf [Exp [A ]]
63
-
64
60
}
65
61
66
-
67
62
trait NumericOpsExpOpt extends NumericOpsExp {
68
63
this : PrimitiveOpsExp =>
69
64
@@ -75,6 +70,7 @@ trait NumericOpsExpOpt extends NumericOpsExp {
75
70
}
76
71
override def numeric_minus [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ): Exp [T ] = (lhs,rhs) match {
77
72
case (Const (x), Const (y)) => Const (implicitly[Numeric [T ]].minus(x,y))
73
+ case (x, Const (y)) if y == implicitly[Numeric [T ]].zero => x
78
74
case _ => super .numeric_minus(lhs,rhs)
79
75
}
80
76
override def numeric_times [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ): Exp [T ] = (lhs,rhs) match {
@@ -88,6 +84,8 @@ trait NumericOpsExpOpt extends NumericOpsExp {
88
84
override def numeric_divide [T : Numeric : Typ ](lhs : Exp [T ], rhs : Exp [T ])(implicit pos : SourceContext ): Exp [T ] = (lhs,rhs) match {
89
85
// CAVEAT: Numeric doesn't have .div, Fractional has
90
86
case (Const (x), Const (y)) => Const (implicitly[Numeric [T ]].asInstanceOf [Fractional [T ]].div(x,y))
87
+ case (Const (x), y) if x == implicitly[Numeric [T ]].zero => Const (x)
88
+ case (x, Const (y)) if y == implicitly[Numeric [T ]].one => x
91
89
case _ => super .numeric_divide(lhs,rhs)
92
90
}
93
91
}
@@ -98,10 +96,10 @@ trait ScalaGenNumericOps extends ScalaGenFat {
98
96
import IR ._
99
97
100
98
override def emitNode (sym : Sym [Any ], rhs : Def [Any ]) = rhs match {
101
- case NumericPlus (a,b) => emitValDef(sym, src " $a + $b" )
102
- case NumericMinus (a,b) => emitValDef(sym, src " $a - $b" )
103
- case NumericTimes (a,b) => emitValDef(sym, src " $a * $b" )
104
- case NumericDivide (a,b) => emitValDef(sym, src " $a / $b" )
99
+ case NumericPlus (a,b) => emitValDef(sym, src " $a + $b" )
100
+ case NumericMinus (a,b) => emitValDef(sym, src " $a - $b" )
101
+ case NumericTimes (a,b) => emitValDef(sym, src " $a * $b" )
102
+ case NumericDivide (a,b) => emitValDef(sym, src " $a / $b" )
105
103
case _ => super .emitNode(sym, rhs)
106
104
}
107
105
}
@@ -112,14 +110,10 @@ trait CLikeGenNumericOps extends CLikeGenBase {
112
110
113
111
override def emitNode (sym : Sym [Any ], rhs : Def [Any ]) = {
114
112
rhs match {
115
- case NumericPlus (a,b) =>
116
- emitValDef(sym, src " $a + $b" )
117
- case NumericMinus (a,b) =>
118
- emitValDef(sym, src " $a - $b" )
119
- case NumericTimes (a,b) =>
120
- emitValDef(sym, src " $a * $b" )
121
- case NumericDivide (a,b) =>
122
- emitValDef(sym, src " $a / $b" )
113
+ case NumericPlus (a,b) => emitValDef(sym, src " $a + $b" )
114
+ case NumericMinus (a,b) => emitValDef(sym, src " $a - $b" )
115
+ case NumericTimes (a,b) => emitValDef(sym, src " $a * $b" )
116
+ case NumericDivide (a,b) => emitValDef(sym, src " $a / $b" )
123
117
case _ => super .emitNode(sym, rhs)
124
118
}
125
119
}
0 commit comments