diff --git a/src/common/ArrayOps.scala b/src/common/ArrayOps.scala index 669ff26b..d38392a4 100644 --- a/src/common/ArrayOps.scala +++ b/src/common/ArrayOps.scala @@ -3,68 +3,95 @@ package common import java.io.PrintWriter import internal._ -import scala.reflect.SourceContext +import scala.reflect.{SourceContext, RefinedManifest} +import scala.collection.mutable.{HashMap,Set} trait ArrayOps extends Variables { + type Size = Long + // multiple definitions needed because implicits won't chain // not using infix here because apply doesn't work with infix methods implicit def varToArrayOps[T:Manifest](x: Var[Array[T]]) = new ArrayOpsCls(readVar(x)) implicit def repArrayToArrayOps[T:Manifest](a: Rep[Array[T]]) = new ArrayOpsCls(a) + implicit def repArrayAnyToArrayOps(a: Rep[Array[Any]]) = new ArrayOpsCls(a) implicit def arrayToArrayOps[T:Manifest](a: Array[T]) = new ArrayOpsCls(unit(a)) // substitution for "new Array[T](...)" // TODO: look into overriding __new for arrays object NewArray { - def apply[T:Manifest](n: Rep[Int]) = array_obj_new(n) + def apply[T:Manifest](n: Rep[Size], specializedType: Rep[String] = unit("")) = array_obj_new(n, specializedType) } object Array { - def apply[T:Manifest](xs: Rep[T]*) = array_obj_fromseq(xs) + def apply[T:Manifest](xs: T*) = array_obj_fromseq(xs) } class ArrayOpsCls[T:Manifest](a: Rep[Array[T]]){ - def apply(n: Rep[Int])(implicit pos: SourceContext) = array_apply(a, n) - def update(n: Rep[Int], y: Rep[T])(implicit pos: SourceContext) = array_update(a,n,y) + def apply(n: Rep[Size])(implicit pos: SourceContext) = array_apply(a, n) + def update(n: Rep[Size], y: Rep[T])(implicit pos: SourceContext) = array_update(a,n,y) def length(implicit pos: SourceContext) = array_length(a) def foreach(block: Rep[T] => Rep[Unit])(implicit pos: SourceContext) = array_foreach(a, block) + def filter(f: Rep[T] => Rep[Boolean]) = array_filter(a, f) + def groupBy[B: Manifest](f: Rep[T] => Rep[B]) = array_group_by(a,f) def sort(implicit pos: SourceContext) = array_sort(a) def map[B:Manifest](f: Rep[T] => Rep[B]) = array_map(a,f) def toSeq = array_toseq(a) - def slice(start:Rep[Int], end:Rep[Int]) = array_slice(a,start,end) - } - - def array_obj_new[T:Manifest](n: Rep[Int]): Rep[Array[T]] - def array_obj_fromseq[T:Manifest](xs: Seq[Rep[T]]): Rep[Array[T]] - def array_apply[T:Manifest](x: Rep[Array[T]], n: Rep[Int])(implicit pos: SourceContext): Rep[T] - def array_update[T:Manifest](x: Rep[Array[T]], n: Rep[Int], y: Rep[T])(implicit pos: SourceContext): Rep[Unit] - def array_unsafe_update[T:Manifest](x: Rep[Array[T]], n: Rep[Int], y: Rep[T])(implicit pos: SourceContext): Rep[Unit] - def array_length[T:Manifest](x: Rep[Array[T]])(implicit pos: SourceContext) : Rep[Int] + def sum = array_sum(a) + def zip[B: Manifest](a2: Rep[Array[B]]) = array_zip(a,a2) + def corresponds[B: Manifest](a2: Rep[Array[B]]) = array_corresponds(a,a2) + def mkString(del: String = "") = array_mkString(a,del) + def startsWith[B:Manifest](s2: Rep[Array[B]])(implicit pos: SourceContext) = array_startsWith[T,B](a,s2) + def endsWith[B:Manifest](s2: Rep[Array[B]])(implicit pos: SourceContext) = array_endsWith[T,B](a,s2) + def slice(from: Rep[Size], until: Rep[Size]) = array_slice[T](a, from, until) + def hash = array_hash(a) + def containsSlice(a2: Rep[Array[T]]) = array_containsSlice(a,a2) + def indexOfSlice(a2: Rep[Array[T]], idx: Rep[Size]) = array_indexOfSlice(a,a2, idx) + def compare(a2: Rep[Array[T]]) = array_compare(a,a2) + } + + def array_obj_new[T:Manifest](n: Rep[Size], specializedType: Rep[String] = unit("")): Rep[Array[T]] + def array_obj_fromseq[T:Manifest](xs: Seq[T]): Rep[Array[T]] + def array_apply[T:Manifest](x: Rep[Array[T]], n: Rep[Size])(implicit pos: SourceContext): Rep[T] + def array_update[T:Manifest](x: Rep[Array[T]], n: Rep[Size], y: Rep[T])(implicit pos: SourceContext): Rep[Unit] + def array_unsafe_update[T:Manifest](x: Rep[Array[T]], n: Rep[Size], y: Rep[T])(implicit pos: SourceContext): Rep[Unit] + def array_length[T:Manifest](x: Rep[Array[T]])(implicit pos: SourceContext) : Rep[Size] def array_foreach[T:Manifest](x: Rep[Array[T]], block: Rep[T] => Rep[Unit])(implicit pos: SourceContext): Rep[Unit] - def array_copy[T:Manifest](src: Rep[Array[T]], srcPos: Rep[Int], dest: Rep[Array[T]], destPos: Rep[Int], len: Rep[Int])(implicit pos: SourceContext): Rep[Unit] - def array_unsafe_copy[T:Manifest](src: Rep[Array[T]], srcPos: Rep[Int], dest: Rep[Array[T]], destPos: Rep[Int], len: Rep[Int])(implicit pos: SourceContext): Rep[Unit] + def array_filter[T : Manifest](l: Rep[Array[T]], f: Rep[T] => Rep[Boolean])(implicit pos: SourceContext): Rep[Array[T]] + def array_group_by[T : Manifest, B: Manifest](l: Rep[Array[T]], f: Rep[T] => Rep[B])(implicit pos: SourceContext): Rep[HashMap[B, Array[T]]] def array_sort[T:Manifest](x: Rep[Array[T]])(implicit pos: SourceContext): Rep[Array[T]] def array_map[A:Manifest,B:Manifest](a: Rep[Array[A]], f: Rep[A] => Rep[B]): Rep[Array[B]] def array_toseq[A:Manifest](a: Rep[Array[A]]): Rep[Seq[A]] - def array_slice[A:Manifest](a: Rep[Array[A]], start:Rep[Int], end:Rep[Int]): Rep[Array[A]] + def array_sum[A:Manifest](a: Rep[Array[A]]): Rep[A] + def array_zip[A:Manifest, B: Manifest](a: Rep[Array[A]], a2: Rep[Array[B]]): Rep[Array[(A,B)]] + def array_corresponds[A: Manifest, B: Manifest](a: Rep[Array[A]], a2: Rep[Array[B]]): Rep[Boolean] // limited support for corresponds (tests equality) + def array_mkString[A: Manifest](a: Rep[Array[A]], del: String = ""): Rep[String] + def array_startsWith[A:Manifest, B:Manifest](s1: Rep[Array[A]], s2: Rep[Array[B]])(implicit pos: SourceContext): Rep[Boolean] + def array_endsWith[A:Manifest, B:Manifest](s1: Rep[Array[A]], s2: Rep[Array[B]])(implicit pos: SourceContext): Rep[Boolean] + def array_slice[A: Manifest](a: Rep[Array[A]], from: Rep[Size], until: Rep[Size]): Rep[Array[A]] + def array_hash[A:Manifest](a: Rep[Array[A]]): Rep[Size] + def array_containsSlice[A:Manifest](s1: Rep[Array[A]], s2: Rep[Array[A]])(implicit pos: SourceContext): Rep[Boolean] + def array_indexOfSlice[A:Manifest](s1: Rep[Array[A]], s2: Rep[Array[A]], idx: Rep[Size])(implicit pos: SourceContext): Rep[Size] + def array_compare[A:Manifest](s1: Rep[Array[A]], s2: Rep[Array[A]])(implicit pos: SourceContext): Rep[Int] + def array_copy[T:Manifest](src: Rep[Array[T]], srcPos: Rep[Size], dest: Rep[Array[T]], destPos: Rep[Size], len: Rep[Size])(implicit pos: SourceContext): Rep[Unit] + def array_unsafe_copy[T:Manifest](src: Rep[Array[T]], srcPos: Rep[Size], dest: Rep[Array[T]], destPos: Rep[Size], len: Rep[Size])(implicit pos: SourceContext): Rep[Unit] } -trait ArrayOpsExp extends ArrayOps with EffectExp with VariablesExp { - case class ArrayNew[T:Manifest](n: Exp[Int]) extends Def[Array[T]] { +trait ArrayOpsExp extends ArrayOps with EffectExp with VariablesExp with StructExp with WhileExp with OrderingOps with PrimitiveOps with NumericOps { + case class ArrayNew[T:Manifest](n: Exp[Size], specializedType: Rep[String] = unit("")) extends Def[Array[T]] { val m = manifest[T] } - case class ArrayFromSeq[T:Manifest](xs: Seq[Exp[T]]) extends Def[Array[T]] { + case class ArrayFromSeq[T:Manifest](xs: Seq[T]) extends Def[Array[T]] { val m = manifest[T] } - case class ArrayApply[T:Manifest](a: Exp[Array[T]], n: Exp[Int]) extends Def[T] - case class ArrayUpdate[T:Manifest](a: Exp[Array[T]], n: Exp[Int], y: Exp[T]) extends Def[Unit] - case class ArrayLength[T:Manifest](a: Exp[Array[T]]) extends Def[Int] { + case class ArrayApply[T:Manifest](a: Exp[Array[T]], n: Exp[Size]) extends Def[T] + case class ArrayUpdate[T:Manifest](a: Exp[Array[T]], n: Exp[Size], y: Exp[T]) extends Def[Unit] + case class ArrayLength[T:Manifest](a: Exp[Array[T]]) extends Def[Size] { val m = manifest[T] } case class ArrayForeach[T](a: Exp[Array[T]], x: Sym[T], block: Block[Unit]) extends Def[Unit] - case class ArrayCopy[T:Manifest](src: Exp[Array[T]], srcPos: Exp[Int], dest: Exp[Array[T]], destPos: Exp[Int], len: Exp[Int]) extends Def[Unit] { - val m = manifest[T] - } + case class ArrayFilter[T : Manifest](l: Exp[Array[T]], x: Sym[T], block: Block[Boolean]) extends Def[Array[T]] + case class ArrayGroupBy[T: Manifest, B: Manifest](l: Exp[Array[T]], x: Sym[T], block: Block[B]) extends Def[HashMap[B, Array[T]]] case class ArraySort[T:Manifest](x: Exp[Array[T]]) extends Def[Array[T]] { val m = manifest[T] } @@ -72,21 +99,49 @@ trait ArrayOpsExp extends ArrayOps with EffectExp with VariablesExp { val array = NewArray[B](a.length) } case class ArrayToSeq[A:Manifest](x: Exp[Array[A]]) extends Def[Seq[A]] - case class ArraySlice[A:Manifest](a: Exp[Array[A]], s:Exp[Int], e:Exp[Int]) extends Def[Array[A]] - - def array_obj_new[T:Manifest](n: Exp[Int]) = reflectMutable(ArrayNew(n)) - def array_obj_fromseq[T:Manifest](xs: Seq[Exp[T]]) = /*reflectMutable(*/ ArrayFromSeq(xs) /*)*/ - def array_apply[T:Manifest](x: Exp[Array[T]], n: Exp[Int])(implicit pos: SourceContext): Exp[T] = ArrayApply(x, n) - def array_update[T:Manifest](x: Exp[Array[T]], n: Exp[Int], y: Exp[T])(implicit pos: SourceContext) = reflectWrite(x)(ArrayUpdate(x,n,y)) - def array_unsafe_update[T:Manifest](x: Rep[Array[T]], n: Rep[Int], y: Rep[T])(implicit pos: SourceContext) = ArrayUpdate(x,n,y) - def array_length[T:Manifest](a: Exp[Array[T]])(implicit pos: SourceContext) : Rep[Int] = ArrayLength(a) + case class ArraySum[A:Manifest](x: Exp[Array[A]]) extends Def[A] + case class ArrayZip[A:Manifest, B: Manifest](x: Exp[Array[A]], x2: Exp[Array[B]]) extends Def[Array[(A,B)]] + case class ArrayCorresponds[A:Manifest, B: Manifest](x: Exp[Array[A]], x2: Exp[Array[B]]) extends Def[Boolean] + case class ArrayMkString[A:Manifest](a: Exp[Array[A]], b: String = "") extends Def[String] + case class ArrayStartsWith[A:Manifest,B:Manifest](s1: Exp[Array[A]], s2: Exp[Array[B]]) extends Def[Boolean] + case class ArrayEndsWith[A:Manifest,B:Manifest](s1: Exp[Array[A]], s2: Exp[Array[B]]) extends Def[Boolean] + case class ArraySlice[A:Manifest](a: Exp[Array[A]], from: Exp[Size], until: Exp[Size]) extends Def[Array[A]] { + val m = manifest[A] + } + case class ArrayContainsSlice[A:Manifest](s1: Exp[Array[A]], s2: Exp[Array[A]]) extends Def[Boolean] + case class ArrayIndexOfSlice[A:Manifest](s1: Exp[Array[A]], s2: Exp[Array[A]], idx: Rep[Size]) extends Def[Size] { + val m = manifest[A] + } + case class ArrayCompare[A:Manifest](s1: Exp[Array[A]], s2: Exp[Array[A]]) extends Def[Int] { + val m = manifest[A] + } + case class ArrayHash[A:Manifest](a: Exp[Array[A]]) extends Def[Size] + case class ArrayCopy[T:Manifest](src: Exp[Array[T]], srcPos: Exp[Size], dest: Exp[Array[T]], destPos: Exp[Size], len: Exp[Size]) extends Def[Unit] { + val m = manifest[T] + } + + def array_obj_new[T:Manifest](n: Exp[Size], specializedType: Rep[String] = unit("")) = reflectMutable(ArrayNew(n, specializedType)) + def array_obj_fromseq[T:Manifest](xs: Seq[T]) = /*reflectMutable(*/ ArrayFromSeq(xs) /*)*/ + def array_apply[T:Manifest](x: Exp[Array[T]], n: Exp[Size])(implicit pos: SourceContext): Exp[T] = ArrayApply(x, n) + def array_update[T:Manifest](x: Exp[Array[T]], n: Exp[Size], y: Exp[T])(implicit pos: SourceContext) = reflectWrite(x)(ArrayUpdate(x,n,y)) + def array_unsafe_update[T:Manifest](x: Rep[Array[T]], n: Rep[Size], y: Rep[T])(implicit pos: SourceContext) = ArrayUpdate(x,n,y) + def array_length[T:Manifest](a: Exp[Array[T]])(implicit pos: SourceContext) : Rep[Size] = ArrayLength(a) def array_foreach[T:Manifest](a: Exp[Array[T]], block: Exp[T] => Exp[Unit])(implicit pos: SourceContext): Exp[Unit] = { val x = fresh[T] val b = reifyEffects(block(x)) reflectEffect(ArrayForeach(a, x, b), summarizeEffects(b).star) } - def array_copy[T:Manifest](src: Exp[Array[T]], srcPos: Exp[Int], dest: Exp[Array[T]], destPos: Exp[Int], len: Exp[Int])(implicit pos: SourceContext) = reflectWrite(dest)(ArrayCopy(src,srcPos,dest,destPos,len)) - def array_unsafe_copy[T:Manifest](src: Exp[Array[T]], srcPos: Exp[Int], dest: Exp[Array[T]], destPos: Exp[Int], len: Exp[Int])(implicit pos: SourceContext) = ArrayCopy(src,srcPos,dest,destPos,len) + def array_filter[T : Manifest](l: Exp[Array[T]], f: Exp[T] => Exp[Boolean])(implicit pos: SourceContext) = { + val a = fresh[T] + val b = reifyEffects(f(a)) + reflectEffect(ArrayFilter(l, a, b), summarizeEffects(b).star) + } + def array_group_by[T : Manifest, B: Manifest](l: Exp[Array[T]], f: Exp[T] => Exp[B])(implicit pos: SourceContext) = { + val a = fresh[T] + val b = reifyEffects(f(a)) + reflectEffect(ArrayGroupBy(l, a, b), summarizeEffects(b).star) + } + def array_sort[T:Manifest](x: Exp[Array[T]])(implicit pos: SourceContext) = ArraySort(x) def array_map[A:Manifest,B:Manifest](a: Exp[Array[A]], f: Exp[A] => Exp[B]) = { val x = fresh[A] @@ -94,8 +149,20 @@ trait ArrayOpsExp extends ArrayOps with EffectExp with VariablesExp { reflectEffect(ArrayMap(a, x, b), summarizeEffects(b)) } def array_toseq[A:Manifest](a: Exp[Array[A]]) = ArrayToSeq(a) - def array_slice[A:Manifest](a: Rep[Array[A]], start:Rep[Int], end:Rep[Int]) = ArraySlice(a,start,end) - + def array_sum[A:Manifest](a: Exp[Array[A]]) = reflectEffect(ArraySum(a)) + def array_zip[A:Manifest, B: Manifest](a: Exp[Array[A]], a2: Exp[Array[B]]) = reflectEffect(ArrayZip(a,a2)) + def array_corresponds[A: Manifest, B: Manifest](a: Rep[Array[A]], a2: Rep[Array[B]]) = reflectEffect(ArrayCorresponds(a,a2)) + def array_mkString[A: Manifest](a: Rep[Array[A]], del: String = "") = reflectEffect(ArrayMkString(a, del)) + def array_startsWith[A:Manifest,B:Manifest](s1: Exp[Array[A]], s2: Exp[Array[B]])(implicit pos: SourceContext) = ArrayStartsWith(s1,s2) + def array_endsWith[A:Manifest,B:Manifest](s1: Exp[Array[A]], s2: Exp[Array[B]])(implicit pos: SourceContext) = ArrayEndsWith(s1,s2) + def array_slice[A: Manifest](a: Exp[Array[A]], from: Exp[Size], until: Exp[Size]) = reflectEffect(ArraySlice[A](a,from,until)) + def array_hash[A:Manifest](a: Rep[Array[A]]) = reflectEffect(ArrayHash(a)) + def array_containsSlice[A:Manifest](s1: Exp[Array[A]], s2: Exp[Array[A]])(implicit pos: SourceContext) = ArrayContainsSlice(s1,s2) + def array_indexOfSlice[A:Manifest](s1: Exp[Array[A]], s2: Exp[Array[A]], idx: Exp[Size])(implicit pos: SourceContext) = ArrayIndexOfSlice(s1,s2,idx) + def array_compare[A:Manifest](s1: Exp[Array[A]], s2: Exp[Array[A]])(implicit pos: SourceContext)= ArrayCompare(s1,s2) + def array_copy[T:Manifest](src: Exp[Array[T]], srcPos: Exp[Size], dest: Exp[Array[T]], destPos: Exp[Size], len: Exp[Size])(implicit pos: SourceContext) = reflectWrite(dest)(ArrayCopy(src,srcPos,dest,destPos,len)) + def array_unsafe_copy[T:Manifest](src: Exp[Array[T]], srcPos: Exp[Size], dest: Exp[Array[T]], destPos: Exp[Size], len: Exp[Size])(implicit pos: SourceContext) = ArrayCopy(src,srcPos,dest,destPos,len) + ////////////// // mirroring @@ -104,30 +171,44 @@ trait ArrayOpsExp extends ArrayOps with EffectExp with VariablesExp { case ArrayLength(x) => array_length(f(x)) case e@ArraySort(x) => array_sort(f(x))(e.m,pos) case e@ArrayCopy(a,ap,d,dp,l) => toAtom(ArrayCopy(f(a),f(ap),f(d),f(dp),f(l))(e.m))(mtype(manifest[A]),pos) - case Reflect(e@ArrayNew(n), u, es) => reflectMirrored(Reflect(ArrayNew(f(n))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(e@ArrayLength(x), u, es) => reflectMirrored(Reflect(ArrayLength(f(x))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(ArrayApply(l,r), u, es) => reflectMirrored(Reflect(ArrayApply(f(l),f(r))(mtype(manifest[A])), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(e@ArraySort(x), u, es) => reflectMirrored(Reflect(ArraySort(f(x))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(ArrayUpdate(l,i,r), u, es) => reflectMirrored(Reflect(ArrayUpdate(f(l),f(i),f(r)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(e@ArrayCopy(a,ap,d,dp,l), u, es) => reflectMirrored(Reflect(ArrayCopy(f(a),f(ap),f(d),f(dp),f(l))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case ArrayCompare(a1, a2) => array_compare(f(a1), f(a2)) + case ArrayStartsWith(a1, a2) => array_startsWith(f(a1), f(a2)) + case ArrayEndsWith(a1, a2) => array_startsWith(f(a1), f(a2)) + case ArrayContainsSlice(a1, a2) => array_startsWith(f(a1), f(a2)) + case Reflect(e@ArraySlice(arr,idx1,idx2), u, es) => reflectMirrored(Reflect(ArraySlice(f(arr), f(idx1), f(idx2))(e.m), mapOver(f,u),f(es)))(mtype(manifest[A])) + case Reflect(e@ArrayIndexOfSlice(arr1,arr2,idx), u, es) => reflectMirrored(Reflect(ArrayIndexOfSlice(f(arr1), f(arr2), f(idx))(e.m), mapOver(f,u),f(es)))(mtype(manifest[A])) + case Reflect(e@ArrayNew(n, sType), u, es) => reflectMirrored(Reflect(ArrayNew(f(n), sType)(e.m), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(e@ArrayLength(x), u, es) => reflectMirrored(Reflect(ArrayLength(f(x))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(ArrayApply(l,r), u, es) => reflectMirrored(Reflect(ArrayApply(f(l),f(r))(mtype(manifest[A])), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(e@ArraySort(x), u, es) => reflectMirrored(Reflect(ArraySort(f(x))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(ArrayUpdate(l,i,r), u, es) => reflectMirrored(Reflect(ArrayUpdate(f(l),f(i),f(r)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(e@ArrayCopy(a,ap,d,dp,l), u, es) => reflectMirrored(Reflect(ArrayCopy(f(a),f(ap),f(d),f(dp),f(l))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A])) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] // why?? override def syms(e: Any): List[Sym[Any]] = e match { case ArrayForeach(a, x, body) => syms(a):::syms(body) case ArrayMap(a, x, body) => syms(a):::syms(body) + case ArrayFilter(a, x, body) => syms(a):::syms(body) + case ArrayGroupBy(a, x, body) => syms(a):::syms(body) + case ArraySlice(a,idx1,idx2) => syms(a) case _ => super.syms(e) } override def boundSyms(e: Any): List[Sym[Any]] = e match { case ArrayForeach(a, x, body) => x :: effectSyms(body) case ArrayMap(a, x, body) => x :: effectSyms(body) + case ArrayFilter(a, x, body) => x :: effectSyms(body) + case ArrayGroupBy(a, x, body) => x::effectSyms(body) + case ArraySlice(a,idx1,idx2) => effectSyms(a) case _ => super.boundSyms(e) } override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { case ArrayForeach(a, x, body) => freqNormal(a):::freqHot(body) case ArrayMap(a, x, body) => freqNormal(a):::freqHot(body) + case ArrayFilter(a, x, body) => freqNormal(a):::freqHot(body) + case ArrayGroupBy(a, x, body) => freqNormal(a):::freqHot(body) case _ => super.symsFreq(e) } @@ -135,22 +216,8 @@ trait ArrayOpsExp extends ArrayOps with EffectExp with VariablesExp { trait ArrayOpsExpOpt extends ArrayOpsExp { - /** - * @author Alen Stojanov (astojanov@inf.ethz.ch) - */ - override def array_length[T:Manifest](a: Exp[Array[T]])(implicit pos: SourceContext) : Rep[Int] = a match { - case Def(ArrayNew(n: Exp[Int])) => n - case Def(ArrayFromSeq(xs)) => Const(xs.size) - case Def(ArraySort(x)) => array_length(x) - case Def(ArrayMap(x, _, _)) => array_length(x) - case Def(Reflect(ArrayNew(n: Exp[Int]), _, _)) => n - case Def(Reflect(ArrayFromSeq(xs), _, _)) => Const(xs.size) - case Def(Reflect(ArraySort(x), _, _)) => array_length(x) - case Def(Reflect(ArrayMap(x, _, _), _, _)) => array_length(x) - case _ => super.array_length(a) - } - override def array_apply[T:Manifest](x: Exp[Array[T]], n: Exp[Int])(implicit pos: SourceContext): Exp[T] = { + override def array_apply[T:Manifest](x: Exp[Array[T]], n: Exp[Size])(implicit pos: SourceContext): Exp[T] = { if (context ne null) { // find the last modification of array x // if it is an assigment at index n, just return the last value assigned @@ -168,7 +235,7 @@ trait ArrayOpsExpOpt extends ArrayOpsExp { } } - override def array_update[T:Manifest](x: Exp[Array[T]], n: Exp[Int], y: Exp[T])(implicit pos: SourceContext) = { + override def array_update[T:Manifest](x: Exp[Array[T]], n: Exp[Size], y: Exp[T])(implicit pos: SourceContext) = { if (context ne null) { // find the last modification of array x // if it is an assigment at index n with the same value, just do nothing @@ -186,6 +253,8 @@ trait ArrayOpsExpOpt extends ArrayOpsExp { } } + + } @@ -204,7 +273,10 @@ trait ScalaGenArrayOps extends BaseGenArrayOps with ScalaGenBase { val ARRAY_LITERAL_MAX_SIZE = 1000 override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - case a@ArrayNew(n) => emitValDef(sym, src"new Array[${remap(a.m)}]($n)") + case a@ArrayNew(n, sType) => { + val arrType = if (quote(sType) != "\"\"") quote(sType).replaceAll("\"","") else remap(a.m) + emitValDef(sym, src"new Array[$arrType]($n)") + } case e@ArrayFromSeq(xs) => { emitData(sym, xs) emitValDef(sym, @@ -220,49 +292,77 @@ trait ScalaGenArrayOps extends BaseGenArrayOps with ScalaGenBase { "{import scala.io.Source;(Source.fromFile(\"" + symDataPath(sym) + "\").getLines.map{Integer.parseInt(_)}).toArray}" } else { - "Array(" + (xs map quote).mkString(",") + ")" + src"Array($xs)" } ) } case ArrayApply(x,n) => emitValDef(sym, src"$x($n)") - case ArrayUpdate(x,n,y) => emitValDef(sym, src"$x($n) = $y") + case ArrayUpdate(x,n,y) => emitAssignment(sym, src"$x($n)", quote(y)) case ArrayLength(x) => emitValDef(sym, src"$x.length") - case ArrayForeach(a,x,block) => - gen"""val $sym = $a.foreach{ - |$x => + case ArrayForeach(a,x,block) => + stream.println(quote(a) + ".foreach{") + gen"""$x => |${nestedBlock(block)} |$block |}""" - case ArrayCopy(src,srcPos,dest,destPos,len) => emitValDef(sym, src"System.arraycopy($src,$srcPos,$dest,$destPos,$len)") - case a@ArraySort(x) => - gen"""val $sym = { - |val d = new Array[${remap(a.m)}]($x.length) - |System.arraycopy($x, 0, d, 0, $x.length) - |scala.util.Sorting.quickSort(d) - |d - |}""" - case n@ArrayMap(a,x,blk) => - gen"""// workaround for refinedManifest problem - |val $sym = { - |val out = ${n.array} - |val in = $a - |var i = 0 - |while (i < in.length) { - |val $x = in(i) - |${nestedBlock(blk)} - |out(i) = $blk - |i += 1 - |} - |out - |}""" - - // stream.println("val " + quote(sym) + " = " + quote(a) + ".map{") - // stream.println(quote(x) + " => ") - // emitBlock(blk) - // stream.println(quote(getBlockResult(blk))) - // stream.println("}") + case a@ArraySort(x) => + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""{ + |val d = new Array[${a.m}]($x.length) + |System.arraycopy($x, 0, d, 0, $x.length) + |scala.util.Sorting.quickSort(d) + |d + |}""" + } + emitValDef(sym, strWriter.toString) + case n@ArrayMap(a,x,blk) => + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + //stream.println("/* workaround for refinedManifest problem */") + gen"""{ + |val out = ${n.array} + |val in = $a + |var i = 0 + |while (i < in.length) { + |val $x = in(i) + |${nestedBlock(blk)} + |out(i) = $blk + |i += 1 + |} + |out + |}""" + } + emitValDef(sym, strWriter.toString) + case ArrayFilter(a,x,blk) => + emitValDef(sym, quote(a) + ".filter(" + quote(x) + "=> {") + emitBlock(blk) + emitBlockResult(blk) + stream.println("})") + case ArrayGroupBy(a,x,blk) => + emitValDef(sym, quote(a) + ".groupBy(" + quote(x) + "=> {") + emitBlock(blk) + emitBlockResult(blk) + stream.println("})") + case ArraySum(a) => emitValDef(sym, quote(a) + ".sum") case ArrayToSeq(a) => emitValDef(sym, src"$a.toSeq") - case ArraySlice(a,s,e) => emitValDef(sym, src"$a.slice($s,$e)") + case ArrayZip(a,a2) => emitValDef(sym, src"$a zip $a2") + case ArrayCorresponds(a,a2) => emitValDef(sym, src"$a.corresponds($a2){_==_}") + case ArrayMkString(a, del) => + if (del != "") + emitValDef(sym, src"$a.mkString($del)") + else + emitValDef(sym, src"$a.mkString") + case ArrayStartsWith(a,a2) => emitValDef(sym, src"$a.startsWith($a2)") + case ArrayEndsWith(a,a2) => emitValDef(sym, src"$a.endsWith($a2)") + case ArraySlice(a,from,until) => emitValDef(sym, src"$a.slice($from,$until)") + case ArrayHash(a) => emitValDef(sym, src"$a.foldLeft(0) { (cnt,x) => cnt + x.## }") + case ArrayContainsSlice(a,a2) => emitValDef(sym, src"$a.containsSlice($a2)") + case ArrayIndexOfSlice(a,a2,idx) => emitValDef(sym, src"$a.indexOfSlice($a2,$idx)") + case ArrayCompare(a,a2) => emitValDef(sym, "(" + quote(a) + ".zip(" + quote(a2) + ")).foldLeft(0){ (res, elem) => if (res == 0) elem._1 - elem._2 else res}") + case ArrayCopy(src,srcPos,dest,destPos,len) => emitValDef(sym, src"System.arraycopy($src,$srcPos,$dest,$destPos,$len)") case _ => super.emitNode(sym, rhs) } } @@ -273,10 +373,9 @@ trait CLikeGenArrayOps extends BaseGenArrayOps with CLikeGenBase { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { rhs match { - case ArrayLength(x) => emitValDef(sym, src"sizeof($x)/sizeof(*$x)") // WARN: statically allocated elements only - case ArrayApply(x,n) => emitValDef(sym, src"$x[$n]") + case ArrayLength(x) => emitValDef(sym, src"$x.length") + case ArrayApply(x,n) => emitValDef(sym, src"$x.apply($n)") case ArrayUpdate(x,n,y) => stream.println(src"$x.update($n,$y);") - case ArraySlice(x,s,e) => val tp=remap(x.tp.typeArguments(0)); emitValDef(sym, src"({ size_t sz=sizeof("+tp+")*($e-$s); "+tp+"* r = ("+tp+"*)malloc(sz); memcpy(r,(("+tp+"*)$x)+$s,sz); r; })") case _ => super.emitNode(sym, rhs) } } @@ -284,17 +383,31 @@ trait CLikeGenArrayOps extends BaseGenArrayOps with CLikeGenBase { trait CudaGenArrayOps extends CudaGenBase with CLikeGenArrayOps trait OpenCLGenArrayOps extends OpenCLGenBase with CLikeGenArrayOps -trait CGenArrayOps extends CGenBase with BaseGenArrayOps { - val IR: ArrayOpsExp - import IR._ - - override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { - rhs match { - case ArrayLength(x) => emitValDef(sym, quote(x) + "->length") - case ArrayApply(x,n) => emitValDef(sym, quote(x) + "->apply(" + quote(n) + ")") - case ArrayUpdate(x,n,y) => stream.println(quote(x) + "->update(" + quote(n) + "," + quote(y) + ");") - case _ => super.emitNode(sym, rhs) - } - } +trait CGenArrayOps extends CGenEffect with CGenStruct { + val IR: ArrayOpsExp + import IR._ + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { + rhs match { + case a@ArrayNew(n, sType) => { + val arrType = if (quote(sType) != "\"\"") quote(sType) else remap(a.m) + stream.println(arrType + "* " + quote(sym) + " = " + getMemoryAllocString(quote(n), arrType)) + } + case ArrayForeach(a,x,blk) => { + stream.println("{") + stream.println("int i = 0;") + stream.println("for ( i = 0 ; i < " + quote(findInitSymbol(a)) + "Size; i += 1) {") + emitBlock(blk) + emitBlockResult(blk) + stream.println("}") + stream.println("};") + } + case ArrayApply(x,n) => emitValDef(sym, quote(x) + "[" + quote(n) + "]") + case ArrayUpdate(x,n,y) => stream.println(quote(x) + "[" + quote(n) + "] = " + quote(y) + ";") + case ArrayCopy(src,s1,dest,s2,len) => + stream.println("memcpy(" + quote(dest) + "," + quote(src) + "," + quote(len) + ");") + stream.println(quote(dest) + "[" + quote(len) + "] = '\\0';") + case _ => super.emitNode(sym, rhs) + } + } } - diff --git a/src/common/BooleanOps.scala b/src/common/BooleanOps.scala index 75128bc3..f3f7f09e 100644 --- a/src/common/BooleanOps.scala +++ b/src/common/BooleanOps.scala @@ -3,6 +3,7 @@ package common import java.io.PrintWriter import scala.reflect.SourceContext +import scala.lms.internal._ trait LiftBoolean { this: Base => @@ -10,93 +11,145 @@ trait LiftBoolean { implicit def boolToBoolRep(b: Boolean) = unit(b) } -trait BooleanOps extends Variables { +trait BooleanOps extends Variables with Expressions { def infix_unary_!(x: Rep[Boolean])(implicit pos: SourceContext) = boolean_negate(x) - def infix_&&(lhs: Rep[Boolean], rhs: =>Rep[Boolean])(implicit pos: SourceContext) = boolean_and(lhs,rhs) - def infix_||(lhs: Rep[Boolean], rhs: =>Rep[Boolean])(implicit pos: SourceContext) = boolean_or(lhs,rhs) - - // TODO: short-circuit by default + def infix_&&(lhs: Rep[Boolean], rhs: => Rep[Boolean])(implicit pos: SourceContext) = boolean_and(lhs,rhs) + def infix_&&(lhs: Boolean, rhs: => Rep[Boolean])(implicit pos: SourceContext): Exp[Boolean] = { + if (lhs == true) rhs.asInstanceOf[Exp[Boolean]] + else Const(false) + } + def infix_||(lhs: Rep[Boolean], rhs: => Rep[Boolean])(implicit pos: SourceContext) = boolean_or(lhs,rhs) + def infix_||(lhs: Boolean, rhs: => Rep[Boolean])(implicit pos: SourceContext): Exp[Boolean] = { + if (lhs == true) Const(true) + else rhs.asInstanceOf[Exp[Boolean]] + } def boolean_negate(lhs: Rep[Boolean])(implicit pos: SourceContext): Rep[Boolean] - def boolean_and(lhs: Rep[Boolean], rhs: Rep[Boolean])(implicit pos: SourceContext): Rep[Boolean] - def boolean_or(lhs: Rep[Boolean], rhs: Rep[Boolean])(implicit pos: SourceContext): Rep[Boolean] + def boolean_and(lhs: Rep[Boolean], rhs: => Rep[Boolean])(implicit pos: SourceContext): Rep[Boolean] + def boolean_or(lhs: Rep[Boolean], rhs: => Rep[Boolean])(implicit pos: SourceContext): Rep[Boolean] } -trait BooleanOpsExp extends BooleanOps with EffectExp { +trait BooleanOpsExp extends BooleanOps with BaseExp with EffectExp { case class BooleanNegate(lhs: Exp[Boolean]) extends Def[Boolean] - case class BooleanAnd(lhs: Exp[Boolean], rhs: Exp[Boolean]) extends Def[Boolean] - case class BooleanOr(lhs: Exp[Boolean], rhs: Exp[Boolean]) extends Def[Boolean] + case class BooleanAnd(lhs: Exp[Boolean], rhs: Block[Boolean]) extends Def[Boolean] { + val c = fresh[Boolean] // used in c code generation + } + case class BooleanOr(lhs: Exp[Boolean], rhs: Block[Boolean]) extends Def[Boolean] { + val c = fresh[Boolean] // used in c code generation + } def boolean_negate(lhs: Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = BooleanNegate(lhs) - def boolean_and(lhs: Exp[Boolean], rhs: Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = BooleanAnd(lhs,rhs) - def boolean_or(lhs: Exp[Boolean], rhs: Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = BooleanOr(lhs,rhs) + def boolean_and(lhs: Exp[Boolean], frhs: => Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = { + lhs match { + case x@Const(false) => x + case x@Const(true) => frhs + case _ => { + val rhs = reifyEffects(frhs) + BooleanAnd(lhs,rhs) + } + } + } + def boolean_or(lhs: Exp[Boolean], frhs: => Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = { + lhs match { + case x@Const(true) => x + case x@Const(false) => frhs + case _ => { + val rhs = reifyEffects(frhs) + BooleanOr(lhs,rhs) + } + } + } override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { case BooleanNegate(x) => boolean_negate(f(x)) - case BooleanAnd(x,y) => boolean_and(f(x),f(y)) - case BooleanOr(x,y) => boolean_or(f(x),f(y)) - - case Reflect(BooleanNegate(x), u, es) => reflectMirrored(Reflect(BooleanNegate(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(BooleanAnd(x,y), u, es) => reflectMirrored(Reflect(BooleanAnd(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(BooleanOr(x,y), u, es) => reflectMirrored(Reflect(BooleanOr(f(x),f(y)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case BooleanAnd(x,y) => toAtom(BooleanAnd(f(x),f(y))) + case BooleanOr(x,y) => toAtom(BooleanOr(f(x),f(y))) case _ => super.mirror(e, f) }).asInstanceOf[Exp[A]] // why?? -} + + override def syms(e: Any): List[Sym[Any]] = e match { + case BooleanAnd(lhs,rhs) => syms(lhs):::syms(rhs) + case BooleanOr(lhs,rhs) => syms(lhs):::syms(rhs) + case _ => super.syms(e) + } + override def boundSyms(e: Any): List[Sym[Any]] = e match { + case BooleanAnd(lhs,rhs) => effectSyms(lhs) ::: effectSyms(rhs) + case BooleanOr(lhs,rhs) => effectSyms(lhs) ::: effectSyms(rhs) + case _ => super.boundSyms(e) + } -/** - * @author Alen Stojanov (astojanov@inf.ethz.ch) - */ -trait BooleanOpsExpOpt extends BooleanOpsExp { + override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { + case BooleanAnd(a, x) => freqHot(a):::freqCold(x) + case BooleanOr(a, x) => freqHot(a):::freqCold(x) + case _ => super.symsFreq(e) + } + +} + +trait BooleanOpsExpOpt extends BooleanOpsExp { override def boolean_negate(lhs: Exp[Boolean])(implicit pos: SourceContext) = lhs match { case Def(BooleanNegate(x)) => x - case Const(a) => Const(!a) case _ => super.boolean_negate(lhs) } - - override def boolean_and(lhs: Exp[Boolean], rhs: Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = { - (lhs, rhs) match { - case (Const(false), _) => Const(false) - case (_, Const(false)) => Const(false) - case (Const(true), x) => x - case (x, Const(true)) => x - case _ => super.boolean_and(lhs, rhs) - } - } - - override def boolean_or(lhs: Exp[Boolean], rhs: Exp[Boolean])(implicit pos: SourceContext) : Exp[Boolean] = { - (lhs, rhs) match { - case (Const(false), x) => x - case (x, Const(false)) => x - case (Const(true), _) => Const(true) - case (_, Const(true)) => Const(true) - case _ => super.boolean_or(lhs, rhs) - } - } } -trait ScalaGenBooleanOps extends ScalaGenBase { +trait ScalaGenBooleanOps extends ScalaGenBase with GenericNestedCodegen { val IR: BooleanOpsExp import IR._ override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case BooleanNegate(b) => emitValDef(sym, src"!$b") - case BooleanAnd(lhs,rhs) => emitValDef(sym, src"$lhs && $rhs") - case BooleanOr(lhs,rhs) => emitValDef(sym, src"$lhs || $rhs") + case BooleanAnd(lhs,rhs) => + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""if ($lhs == true) { + |${nestedBlock(rhs)} + |$rhs + |} else false""" + } + emitValDef(sym, strWriter.toString) + case BooleanOr(lhs,rhs) => + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""if ($lhs == false) { + |${nestedBlock(rhs)} + |$rhs + |} else true""" + } + emitValDef(sym, strWriter.toString) case _ => super.emitNode(sym,rhs) } } -trait CLikeGenBooleanOps extends CLikeGenBase { +trait CLikeGenBooleanOps extends CLikeGenBase with GenericNestedCodegen { val IR: BooleanOpsExp import IR._ - override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - case BooleanNegate(b) => emitValDef(sym, src"!$b") - case BooleanAnd(lhs,rhs) => emitValDef(sym, src"$lhs && $rhs") - case BooleanOr(lhs,rhs) => emitValDef(sym, src"$lhs || $rhs") - case _ => super.emitNode(sym,rhs) + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { + rhs match { + case BooleanNegate(b) => emitValDef(sym, src"!$b") + case b@BooleanAnd(lhs,rhs) => { + emitValDef(b.c, quote(lhs)) + stream.println("if (" + quote(lhs) + ") {") + emitBlock(rhs) + stream.println(quote(b.c) + " = " + quote(getBlockResult(rhs)) + ";") + stream.println("}") + emitValDef(sym, quote(b.c)) + } + case b@BooleanOr(lhs,rhs) => { + emitValDef(b.c, quote(lhs)) + stream.println("if (" + quote(lhs) + " == false) {") + emitBlock(rhs) + stream.println(quote(b.c) + " = " + quote(getBlockResult(rhs)) + ";") + stream.println("}") + emitValDef(sym, quote(b.c)) + } + case _ => super.emitNode(sym,rhs) + } } } diff --git a/src/common/DSLBase.scala b/src/common/DSLBase.scala new file mode 100644 index 00000000..ea02712c --- /dev/null +++ b/src/common/DSLBase.scala @@ -0,0 +1,129 @@ +package scala.lms +package common + +trait DSLBase extends BaseExp with UncheckedOps { + // keep track of top level functions + abstract class TopLevel(n: String) { + val name = n; + } + + val rec = new scala.collection.mutable.HashMap[String, TopLevel] + + case class TopLevel1 [A1, B](n: String, mA1: Manifest[A1], mB: Manifest[B], f: (Rep[A1]) => Rep[B]) extends TopLevel(n) + case class TopLevel2 [A1, A2, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mB: Manifest[B], f: (Rep[A1], Rep[A2]) => Rep[B]) extends TopLevel(n) + case class TopLevel3 [A1, A2, A3, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3]) => Rep[B]) extends TopLevel(n) + case class TopLevel4 [A1, A2, A3, A4, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4]) => Rep[B]) extends TopLevel(n) + case class TopLevel5 [A1, A2, A3, A4, A5, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5]) => Rep[B]) extends TopLevel(n) + case class TopLevel6 [A1, A2, A3, A4, A5, A6, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6]) => Rep[B]) extends TopLevel(n) + case class TopLevel7 [A1, A2, A3, A4, A5, A6, A7, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7]) => Rep[B]) extends TopLevel(n) + case class TopLevel8 [A1, A2, A3, A4, A5, A6, A7, A8, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8]) => Rep[B]) extends TopLevel(n) + case class TopLevel9 [A1, A2, A3, A4, A5, A6, A7, A8, A9, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9]) => Rep[B]) extends TopLevel(n) + case class TopLevel10 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10]) => Rep[B]) extends TopLevel(n) + case class TopLevel11 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11]) => Rep[B]) extends TopLevel(n) + case class TopLevel12 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12]) => Rep[B]) extends TopLevel(n) + case class TopLevel13 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13]) => Rep[B]) extends TopLevel(n) + case class TopLevel14 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14]) => Rep[B]) extends TopLevel(n) + case class TopLevel15 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14], mA15: Manifest[A15], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15]) => Rep[B]) extends TopLevel(n) + case class TopLevel16 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14], mA15: Manifest[A15], mA16: Manifest[A16], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16]) => Rep[B]) extends TopLevel(n) + case class TopLevel17 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14], mA15: Manifest[A15], mA16: Manifest[A16], mA17: Manifest[A17], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17]) => Rep[B]) extends TopLevel(n) + case class TopLevel18 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14], mA15: Manifest[A15], mA16: Manifest[A16], mA17: Manifest[A17], mA18: Manifest[A18], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17], Rep[A18]) => Rep[B]) extends TopLevel(n) + case class TopLevel19 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, B](n: String, mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14], mA15: Manifest[A15], mA16: Manifest[A16], mA17: Manifest[A17], mA18: Manifest[A18], mA19: Manifest[A19], mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17], Rep[A18], Rep[A19]) => Rep[B]) extends TopLevel(n) + + //case class TopLevel14_16 [A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, A30, B](n: String, (mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mA6: Manifest[A6], mA7: Manifest[A7], mA8: Manifest[A8], mA9: Manifest[A9], mA10: Manifest[A10], mA11: Manifest[A11], mA12: Manifest[A12], mA13: Manifest[A13], mA14: Manifest[A14]), (mA15: Manifest[A15], mA16: Manifest[A16], mA17: Manifest[A17], mA18: Manifest[A18], mA19: Manifest[A19], mA20: Manifest[A20], mA21: Manifest[A21], mA22: Manifest[A22], mA23: Manifest[A23], mA24: Manifest[A24], mA25: Manifest[A25], mA26: Manifest[A26], mA27: Manifest[A27], mA28: Manifest[A28], mA29: Manifest[A29], mA30: Manifest[A30]), mB: Manifest[B], f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14])(Rep[A15], Rep[A16], Rep[A17], Rep[A18], Rep[A19], Rep[A20], Rep[A21], Rep[A22], Rep[A23], Rep[A24], Rep[A25], Rep[A26], Rep[A27], Rep[A28], Rep[A29], Rep[A30]) => Rep[B]) extends TopLevel(n) + + def toplevel1[A1: Manifest, B: Manifest](name: String)(f: (Rep[A1]) => Rep[B]): (Rep[A1]) => Rep[B] = { + val g = (x1: Rep[A1]) => unchecked[B](name, "(", x1, ")") + rec.getOrElseUpdate(name, TopLevel1(name, manifest[A1], manifest[B], f)) + g + } + def toplevel2[A1: Manifest, A2: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2]) => Rep[B]): (Rep[A1], Rep[A2]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2]) => unchecked[B](name, "(", x1, ",", x2, ")") + rec.getOrElseUpdate(name, TopLevel2(name, manifest[A1], manifest[A2], manifest[B], f)) + g + } + def toplevel3[A1: Manifest, A2: Manifest, A3: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ")") + rec.getOrElseUpdate(name, TopLevel3(name, manifest[A1], manifest[A2], manifest[A3], manifest[B], f)) + g + } + def toplevel4[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ")") + rec.getOrElseUpdate(name, TopLevel4(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[B], f)) + g + } + def toplevel5[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ")") + rec.getOrElseUpdate(name, TopLevel5(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[B], f)) + g + } + def toplevel6[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ")") + rec.getOrElseUpdate(name, TopLevel6(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[B], f)) + g + } + def toplevel7[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ")") + rec.getOrElseUpdate(name, TopLevel7(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[B], f)) + g + } + def toplevel8[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ")") + rec.getOrElseUpdate(name, TopLevel8(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[B], f)) + g + } + def toplevel9[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ")") + rec.getOrElseUpdate(name, TopLevel9(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[B], f)) + g + } + def toplevel10[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ")") + rec.getOrElseUpdate(name, TopLevel10(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[B], f)) + g + } + def toplevel11[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ")") + rec.getOrElseUpdate(name, TopLevel11(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[B], f)) + g + } + def toplevel12[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ")") + rec.getOrElseUpdate(name, TopLevel12(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[B], f)) + g + } + def toplevel13[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ")") + rec.getOrElseUpdate(name, TopLevel13(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[B], f)) + g + } + def toplevel14[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, A14: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13], x14: Rep[A14]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ",", x14, ")") + rec.getOrElseUpdate(name, TopLevel14(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[A14], manifest[B], f)) + g + } + def toplevel15[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, A14: Manifest, A15: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13], x14: Rep[A14], x15: Rep[A15]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ",", x14, ",", x15, ")") + rec.getOrElseUpdate(name, TopLevel15(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[A14], manifest[A15], manifest[B], f)) + g + } + def toplevel16[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, A14: Manifest, A15: Manifest, A16: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13], x14: Rep[A14], x15: Rep[A15], x16: Rep[A16]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ",", x14, ",", x15, ",", x16, ")") + rec.getOrElseUpdate(name, TopLevel16(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[A14], manifest[A15], manifest[A16], manifest[B], f)) + g + } + def toplevel17[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, A14: Manifest, A15: Manifest, A16: Manifest, A17: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13], x14: Rep[A14], x15: Rep[A15], x16: Rep[A16], x17: Rep[A17]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ",", x14, ",", x15, ",", x16, ",", x17, ")") + rec.getOrElseUpdate(name, TopLevel17(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[A14], manifest[A15], manifest[A16], manifest[A17], manifest[B], f)) + g + } + def toplevel18[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, A14: Manifest, A15: Manifest, A16: Manifest, A17: Manifest, A18: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17], Rep[A18]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17], Rep[A18]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13], x14: Rep[A14], x15: Rep[A15], x16: Rep[A16], x17: Rep[A17], x18: Rep[A18]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ",", x14, ",", x15, ",", x16, ",", x17, ",", x18, ")") + rec.getOrElseUpdate(name, TopLevel18(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[A14], manifest[A15], manifest[A16], manifest[A17], manifest[A18], manifest[B], f)) + g + } + def toplevel19[A1: Manifest, A2: Manifest, A3: Manifest, A4: Manifest, A5: Manifest, A6: Manifest, A7: Manifest, A8: Manifest, A9: Manifest, A10: Manifest, A11: Manifest, A12: Manifest, A13: Manifest, A14: Manifest, A15: Manifest, A16: Manifest, A17: Manifest, A18: Manifest, A19: Manifest, B: Manifest](name: String)(f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17], Rep[A18], Rep[A19]) => Rep[B]): (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5], Rep[A6], Rep[A7], Rep[A8], Rep[A9], Rep[A10], Rep[A11], Rep[A12], Rep[A13], Rep[A14], Rep[A15], Rep[A16], Rep[A17], Rep[A18], Rep[A19]) => Rep[B] = { + val g = (x1: Rep[A1], x2: Rep[A2], x3: Rep[A3], x4: Rep[A4], x5: Rep[A5], x6: Rep[A6], x7: Rep[A7], x8: Rep[A8], x9: Rep[A9], x10: Rep[A10], x11: Rep[A11], x12: Rep[A12], x13: Rep[A13], x14: Rep[A14], x15: Rep[A15], x16: Rep[A16], x17: Rep[A17], x18: Rep[A18], x19: Rep[A19]) => unchecked[B](name, "(", x1, ",", x2, ",", x3, ",", x4, ",", x5, ",", x6, ",", x7, ",", x8, ",", x9, ",", x10, ",", x11, ",", x12, ",", x13, ",", x14, ",", x15, ",", x16, ",", x17, ",", x18, ",", x19, ")") + rec.getOrElseUpdate(name, TopLevel19(name, manifest[A1], manifest[A2], manifest[A3], manifest[A4], manifest[A5], manifest[A6], manifest[A7], manifest[A8], manifest[A9], manifest[A10], manifest[A11], manifest[A12], manifest[A13], manifest[A14], manifest[A15], manifest[A16], manifest[A17], manifest[A18], manifest[A19], manifest[B], f)) + g + } +} \ No newline at end of file diff --git a/src/common/Functions.scala b/src/common/Functions.scala index f956571d..90982e6b 100644 --- a/src/common/Functions.scala +++ b/src/common/Functions.scala @@ -3,13 +3,13 @@ package common import java.io.PrintWriter -import scala.lms.internal.{GenericNestedCodegen, GenerationFailedException} +import scala.lms.internal.{GenericNestedCodegen, GenerationFailedException, CNestedCodegen} import scala.lms.util.ClosureCompare import scala.reflect.SourceContext trait Functions extends Base { - + def doLambda[A:Manifest,B:Manifest](fun: Rep[A] => Rep[B])(implicit pos: SourceContext): Rep[A => B] implicit def fun[A:Manifest,B:Manifest](f: Rep[A] => Rep[B]): Rep[A=>B] = doLambda(f) @@ -20,6 +20,10 @@ trait Functions extends Base { } def doApply[A:Manifest,B:Manifest](fun: Rep[A => B], arg: Rep[A])(implicit pos: SourceContext): Rep[B] + def uninlinedFunc0[B:Manifest](f: Function0[Rep[B]]): Rep[Unit=>B] + def uninlinedFunc1[A:Manifest,B:Manifest](f: Function1[Rep[A],Rep[B]]): Rep[A=>B] + def uninlinedFunc2[A1:Manifest,A2:Manifest,B:Manifest](f: Function2[Rep[A1],Rep[A2],Rep[B]]): Rep[(A1,A2)=>B] + def uninlinedFunc3[A1:Manifest,A2:Manifest,A3:Manifest,B:Manifest](f: Function3[Rep[A1],Rep[A2],Rep[A3],Rep[B]]): Rep[(A1,A2,A3)=>B] } trait TupledFunctions extends Functions with TupleOps { @@ -34,6 +38,9 @@ trait TupledFunctions extends Functions with TupleOps { implicit def fun[A1:Manifest,A2:Manifest,A3:Manifest,A4:Manifest,A5:Manifest,B:Manifest](f: (Rep[A1], Rep[A2], Rep[A3], Rep[A4], Rep[A5]) => Rep[B]): Rep[((A1,A2,A3,A4,A5))=>B] = fun((t: Rep[(A1,A2,A3,A4,A5)]) => f(tuple5_get1(t), tuple5_get2(t), tuple5_get3(t), tuple5_get4(t), tuple5_get5(t))) + class LambdaOps0[B:Manifest](f: Rep[Unit => B]) { + def apply() = doApply(f,unit()) + } class LambdaOps2[A1:Manifest,A2:Manifest,B:Manifest](f: Rep[((A1,A2)) => B]) { def apply(x1: Rep[A1], x2: Rep[A2]) = doApply(f,(x1, x2)) def apply(x: Rep[(A1,A2)]): Rep[B] = doApply(f,x) @@ -52,6 +59,8 @@ trait TupledFunctions extends Functions with TupleOps { } implicit def toLambdaOpsAny[B:Manifest](fun: Rep[Any => B]) = toLambdaOps(fun) + implicit def toLambdaOps0[B:Manifest](fun: Rep[Unit => B]) = + new LambdaOps0(fun) implicit def toLambdaOps2[A1:Manifest,A2:Manifest,B:Manifest](fun: Rep[((A1,A2)) => B]) = new LambdaOps2(fun) implicit def toLambdaOps3[A1:Manifest,A2:Manifest,A3:Manifest,B:Manifest](fun: Rep[((A1,A2,A3)) => B]) = @@ -64,7 +73,29 @@ trait TupledFunctions extends Functions with TupleOps { trait FunctionsExp extends Functions with EffectExp { case class Lambda[A:Manifest,B:Manifest](f: Exp[A] => Exp[B], x: Exp[A], y: Block[B]) extends Def[A => B] { val mA = manifest[A]; val mB = manifest[B] } - case class Apply[A:Manifest,B:Manifest](f: Exp[A => B], arg: Exp[A]) extends Def[B] { val mA = manifest[A]; val mB = manifest[B] } + case class Apply[A:Manifest,B:Manifest](f: Exp[A => B], arg: Exp[A]) extends Def[B] { + val mA = manifest[A] + val mB = manifest[B] + } + case class UninlinedFunc0[B:Manifest](b: Block[B]) extends Def[Unit => B] { + val mB = manifest[B] + } + case class UninlinedFunc1[A:Manifest,B:Manifest](s:Sym[A], b: Block[B]) extends Def[A => B] { + val mA = manifest[A] + val mB = manifest[B] + } + case class UninlinedFunc2[A1:Manifest,A2:Manifest,B:Manifest](s1:Sym[A1], s2:Sym[A2], b: Block[B]) extends Def[(A1,A2) => B] { + val mA1 = manifest[A1] + val mA2 = manifest[A2] + val mB = manifest[B] + } + case class UninlinedFunc3[A1:Manifest,A2:Manifest,A3:Manifest,B:Manifest](s1:Sym[A1], s2:Sym[A2], s3:Sym[A3], b: Block[B]) extends Def[(A1,A2,A3) => B] { + val mA1 = manifest[A1] + val mA2 = manifest[A2] + val mA3 = manifest[A3] + val mB = manifest[B] + } + // unboxedFresh and unbox are hooks that can be overridden to // implement multiple-arity functions with tuples. These two methods @@ -98,10 +129,60 @@ trait FunctionsExp extends Functions with EffectExp { } } + /* BEGINNING UNINLINED FUNCTIONS */ + val functionList0 = new scala.collection.mutable.HashMap[Sym[Any],Block[Any]]() + val functionList1 = new scala.collection.mutable.HashMap[Sym[Any],(Sym[Any],Block[Any])]() + val functionList2 = new scala.collection.mutable.HashMap[Sym[Any],(Sym[Any],Sym[Any],Block[Any])]() + val functionList3 = new scala.collection.mutable.HashMap[Sym[Any],(Sym[Any],Sym[Any],Sym[Any],Block[Any])]() + def uninlinedFunc0[B:Manifest](f: Function0[Rep[B]]) = { + val b = reifyEffects(f()) + uninlinedFunc0(b) + } + def uninlinedFunc0[B:Manifest](b: Block[B]) = { + val l = reflectEffect(UninlinedFunc0(b), Pure()) + functionList0 += (l.asInstanceOf[Sym[Any]] -> b) + l + } + + def uninlinedFunc1[A:Manifest,B:Manifest](f: Function1[Rep[A],Rep[B]]) = { + val s = fresh[A] + val b = reifyEffects(f(s)) + uninlinedFunc1(s,b) + } + def uninlinedFunc1[A:Manifest,B:Manifest](s: Sym[A], b: Block[B]) = { + val l = reflectEffect(UninlinedFunc1(s,b), Pure()) + functionList1 += (l.asInstanceOf[Sym[Any]] -> (s,b)) + l + } + + def uninlinedFunc2[A1:Manifest,A2:Manifest,B:Manifest](f: Function2[Rep[A1],Rep[A2],Rep[B]]) = { + val s1 = fresh[A1] + val s2 = fresh[A2] + val b = reifyEffects(f(s1,s2)) + uninlinedFunc2(s1,s2,b) + } + def uninlinedFunc2[A1:Manifest,A2:Manifest,B:Manifest](s1: Sym[A1], s2: Sym[A2], b: Block[B]) = { + val l = reflectEffect(UninlinedFunc2(s1,s2,b), Pure()) + functionList2 += (l.asInstanceOf[Sym[Any]] -> (s1,s2,b)) + l + } + + def uninlinedFunc3[A1:Manifest,A2:Manifest,A3:Manifest,B:Manifest](f: Function3[Rep[A1],Rep[A2],Rep[A3],Rep[B]]) = { + val s1 = fresh[A1] + val s2 = fresh[A2] + val s3 = fresh[A3] + val b = reifyEffects(f(s1,s2,s3)) + uninlinedFunc3(s1,s2,s3,b) + } + def uninlinedFunc3[A1:Manifest,A2:Manifest,A3:Manifest,B:Manifest](s1: Sym[A1], s2: Sym[A2], s3: Sym[A3], b: Block[B]) = { + val l = reflectEffect(UninlinedFunc3(s1,s2,s3,b), Pure()) + functionList3 += (l.asInstanceOf[Sym[Any]] -> (s1,s2,s3,b)) + l + } + /* END OF UNINLINED FUNCTIONS */ + override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { case e@Lambda(g,x,y) => toAtom(Lambda(f(g),f(x),f(y))(e.mA,e.mB))(mtype(manifest[A]),implicitly[SourceContext]) - case e@Apply(g,arg) => doApply(f(g), f(arg))(e.mA,mtype(e.mB),implicitly[SourceContext]) - case Reflect(e@Apply(g,arg), u, es) => reflectMirrored(Reflect(Apply(f(g),f(arg))(e.mA,mtype(e.mB)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] // why?? @@ -109,8 +190,12 @@ trait FunctionsExp extends Functions with EffectExp { case Lambda(f, x, y) => syms(y) case _ => super.syms(e) } - + override def boundSyms(e: Any): List[Sym[Any]] = e match { + case UninlinedFunc0(f) => effectSyms(f) + case UninlinedFunc1(s,f) => s :: effectSyms(f) + case UninlinedFunc2(s1,s2,f) => s1 :: s2 :: effectSyms(f) + case UninlinedFunc3(s1,s2,s3,f) => s1 :: s2 :: s3 :: effectSyms(f) case Lambda(f, x, y) => syms(x) ::: effectSyms(y) case _ => super.boundSyms(e) } @@ -188,6 +273,22 @@ trait TupledFunctionsExp extends TupledFunctions with FunctionsExp with TupleOps } } + override def aliasSyms(e: Any): List[Sym[Any]] = e match { + case UninlinedFunc0(f) => effectSyms(f) + case UninlinedFunc1(s,f) => s :: effectSyms(f) + case UninlinedFunc2(s1,s2,f) => s1 :: s2 :: effectSyms(f) + case UninlinedFunc3(s1,s2,s3,f) => s1 :: s2 :: s3 :: effectSyms(f) + case _ => super.aliasSyms(e) + } + + override def containSyms(e: Any): List[Sym[Any]] = e match { + case UninlinedFunc0(f) => effectSyms(f) + case UninlinedFunc1(s,f) => s :: effectSyms(f) + case UninlinedFunc2(s1,s2,f) => s1 :: s2 :: effectSyms(f) + case UninlinedFunc3(s1,s2,s3,f) => s1 :: s2 :: s3 :: effectSyms(f) + case _ => super.containSyms(e) + } + override def boundSyms(e: Any): List[Sym[Any]] = e match { case Lambda(f, UnboxedTuple(xs), y) => xs.flatMap(syms) ::: effectSyms(y) case _ => super.boundSyms(e) @@ -223,11 +324,24 @@ trait GenericGenUnboxedTupleAccess extends GenericNestedCodegen { import IR._ override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - case FieldApply(UnboxedTuple(vars), "_1") => emitValDef(sym, quote(vars(0))) - case FieldApply(UnboxedTuple(vars), "_2") => emitValDef(sym, quote(vars(1))) - case FieldApply(UnboxedTuple(vars), "_3") => emitValDef(sym, quote(vars(2))) - case FieldApply(UnboxedTuple(vars), "_4") => emitValDef(sym, quote(vars(3))) - case FieldApply(UnboxedTuple(vars), "_5") => emitValDef(sym, quote(vars(4))) + case Tuple2Access1(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(0))) + case Tuple2Access2(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(1))) + + case Tuple3Access1(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(0))) + case Tuple3Access2(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(1))) + case Tuple3Access3(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(2))) + + case Tuple4Access1(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(0))) + case Tuple4Access2(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(1))) + case Tuple4Access3(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(2))) + case Tuple4Access4(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(3))) + + case Tuple5Access1(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(0))) + case Tuple5Access2(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(1))) + case Tuple5Access3(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(2))) + case Tuple5Access4(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(3))) + case Tuple5Access5(UnboxedTuple(vars)) => emitValDef(sym, quote(vars(4))) + case _ => super.emitNode(sym, rhs) } } @@ -243,16 +357,33 @@ trait ScalaGenFunctions extends ScalaGenEffect with BaseGenFunctions { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case e@Lambda(fun, x, y) => - emitValDef(sym, "{" + quote(x) + ": (" + remap(x.tp) + ") => ") + emitValDef(sym, "{" + quote(x, true) + ": (" + x.tp + ") => ") emitBlock(y) - stream.println(quote(getBlockResult(y)) + ": " + remap(y.tp)) + if (y.tp != manifest[Unit]) stream.println(quote(getBlockResult(y)) + ": " + y.tp) stream.println("}") + case UninlinedFunc0(b) => /* Handled in emitFunctions */ {} + case UninlinedFunc1(s,b) => /* Handled in emitFunctions */ {} + case UninlinedFunc2(s1,s2,b) => /* Handled in emitFunctions */ {} + case UninlinedFunc3(s1,s2,s3,b) => /* Handled in emitFunctions */ {} case Apply(fun, arg) => - emitValDef(sym, quote(fun) + "(" + quote(arg) + ")") + arg match { + case Const(()) => emitValDef(sym, quote(fun) + "()") + case _ => emitValDef(sym, quote(fun) + "(" + quote(arg) + ")") + } case _ => super.emitNode(sym, rhs) } + + override def emitFunctions() = { + functionList0.foreach(func => { + stream.println("def " + quote(func._1) + "() = {") + emitBlock(func._2) + emitBlockResult(func._2) + stream.println("}") + }) + functionList0.clear + } } trait ScalaGenTupledFunctions extends ScalaGenFunctions with GenericGenUnboxedTupleAccess { @@ -266,28 +397,26 @@ trait ScalaGenTupledFunctions extends ScalaGenFunctions with GenericGenUnboxedTu override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case Lambda(fun, UnboxedTuple(xs), y) => - emitValDef(sym, "{" + xs.map(s=>quote(s)+":"+remap(s.tp)).mkString("(",",",")") + " => ") + emitValDef(sym, "{" + xs.map(s=>quote(s, true)+":"+remap(s.tp)).mkString("(",",",")") + " => ") emitBlock(y) - stream.println(quote(getBlockResult(y)) + ": " + remap(y.tp)) + var ytp = remap(y.tp).toString; + if (ytp != manifest[Unit]) stream.println(quote(getBlockResult(y)) + ": " + ytp ) stream.println("}") - case Apply(fun, UnboxedTuple(args)) => emitValDef(sym, quote(fun) + args.map(quote).mkString("(", ",", ")")) - case _ => super.emitNode(sym,rhs) } - - def unwrapTupleStr[A](m: Manifest[A]): Array[String] = { - val s = m.toString + + def unwrapTupleStr(s: String): Array[String] = { if (s.startsWith("scala.Tuple")) s.slice(s.indexOf("[")+1,s.length-1).filter(c => c != ' ').split(",") - else Array(remap(m)) - } - - override def remap[A](m: Manifest[A]): String = m.toString match { + else scala.Array(s) + } + + override def remap[A](m: Manifest[A]): String = m.toString match { case f if f.startsWith("scala.Function") => val targs = m.typeArguments.dropRight(1) val res = remap(m.typeArguments.last) - val targsUnboxed = targs.flatMap(t => unwrapTupleStr(t)) + val targsUnboxed = targs.flatMap(t => unwrapTupleStr(remap(t))) val sep = if (targsUnboxed.length > 0) "," else "" "scala.Function" + (targsUnboxed.length) + "[" + targsUnboxed.mkString(",") + sep + res + "]" @@ -302,11 +431,10 @@ trait CudaGenFunctions extends CudaGenEffect with BaseGenFunctions { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { rhs match { case e@Lambda(fun, x, y) => - throw new GenerationFailedException("CudaGenFunctions: Lambda is not supported yet") // The version for inlined device function - //stream.println(addTab() + "%s %s = %s;".format(remap(x.tp), quote(x), quote(sym)+"_1")) - //emitBlock(y) - //stream.println(addTab() + "%s %s = %s;".format(remap(y.tp), quote(sym), quote(getBlockResult(y)))) + stream.println(addTab() + "%s %s = %s;".format(remap(x.tp), quote(x), quote(sym)+"_1")) + emitBlock(y) + stream.println(addTab() + "%s %s = %s;".format(remap(y.tp), quote(sym), quote(getBlockResult(y)))) // The version for separate device function /* @@ -318,6 +446,7 @@ trait CudaGenFunctions extends CudaGenEffect with BaseGenFunctions { stream.println("return %s;".format(quote(getBlockResult(y)))) stream.println("}") */ + case Apply(fun, arg) => emitValDef(sym, quote(fun) + "(" + quote(arg) + ")") @@ -342,68 +471,94 @@ trait OpenCLGenFunctions extends OpenCLGenEffect with BaseGenFunctions { } } -trait CGenFunctions extends CGenEffect with BaseGenFunctions { +trait CGenFunctions extends CNestedCodegen with CGenEffect with BaseGenFunctions { val IR: FunctionsExp import IR._ - // Case for functions with a single argument (therefore, not tupled) override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case e@Lambda(fun, x, y) => - val retType = remap(getBlockResult(y).tp) - val retTp = if (cppExplicitFunRet == "true") "function<"+retType+"("+remap(x.tp)+")>" else "auto" - stream.println(retTp+" "+quote(sym)+ - " = [&]("+remap(x.tp)+" "+quote(x)+") {") + stream.println(remap(y.tp)+" "+quote(sym)+"("+remap(x.tp)+" "+quote(x)+") {") emitBlock(y) val z = getBlockResult(y) - if (retType != "void") + if (remap(z.tp) != "void") stream.println("return " + quote(z) + ";") - stream.println("};") - case Apply(fun, arg) => - emitValDef(sym, quote(fun) + "(" + quote(arg) + ")") + stream.println("}") + case UninlinedFunc0(b) => /* Handled in emitFunctions */ {} + case UninlinedFunc1(s,b) => /* Handled in emitFunctions */ {} + case UninlinedFunc2(s1,s2,b) => /* Handled in emitFunctions */ {} + case UninlinedFunc3(s1,s2,s3,b) => /* Handled in emitFunctions */ {} + case a@Apply(fun, arg) => + arg match { + case Const(x) => x match { + case t: scala.Tuple2[Exp[_],Exp[_]] => + emitValDef(sym, quote(fun) + "(" + quote(t._1) + "," + quote(t._2) + ")") + case () => emitValDef(sym, quote(fun) + "()") + case _ => emitValDef(sym, quote(fun) + "(" + quote(arg) + ")") + } + case _ => emitValDef(sym, quote(fun) + "(" + quote(arg) + ")") + } case _ => super.emitNode(sym, rhs) } - + override def emitFunctions() = { + // Output prototypes to resolve dependencies + functionList0.foreach(f=>stream.println(remap(getBlockResult(f._2).tp) + " " + quote(f._1) + "();")) + functionList1.foreach(f=>stream.println(remap(getBlockResult(f._2._2).tp) + " " + quote(f._1) + "(" + remap(f._2._1.tp) + " " + quote(f._2._1) + ");")) + functionList2.foreach(f=>stream.println(remap(getBlockResult(f._2._3).tp) + " " + quote(f._1) + "(" + remap(f._2._1.tp) + " " + quote(f._2._1) + ", " + remap(f._2._2.tp) + " " + quote(f._2._2) +");\n")) + functionList3.foreach(f=>stream.println(remap(getBlockResult(f._2._4).tp) + " " + quote(f._1) + "(" + remap(f._2._1.tp) + " " + quote(f._2._1) + ", " + remap(f._2._2.tp) + " " + quote(f._2._2) + ", " + remap(f._2._3.tp) + " " + quote(f._2._3) + ");\n")) + // Output actual functions + functionList0.foreach(func => { + stream.println(remap(getBlockResult(func._2).tp) + " " + quote(func._1) + "() {") + emitBlock(func._2) + stream.println("return " + quote(getBlockResult(func._2)) + ";") + stream.println("}\n") + }) + functionList1.foreach(func => { + stream.print(remap(getBlockResult(func._2._2).tp) + " " + quote(func._1) + "(") + stream.print(remap(func._2._1.tp) + " " + quote(func._2._1)) + stream.println(") {") + emitBlock(func._2._2) + stream.println("return " + quote(getBlockResult(func._2._2)) + ";") + stream.println("}\n") + }) + functionList2.foreach(func => { + stream.print(remap(getBlockResult(func._2._3).tp) + " " + quote(func._1) + "(") + stream.print(remap(func._2._1.tp) + " " + quote(func._2._1) + ", ") + stream.print(remap(func._2._2.tp) + " " + quote(func._2._2)) + stream.println(") {") + emitBlock(func._2._3) + stream.println("return " + quote(getBlockResult(func._2._3)) + ";") + stream.println("}\n") + }) + functionList3.foreach(func => { + stream.print(remap(getBlockResult(func._2._4).tp) + " " + quote(func._1) + "(") + stream.print(remap(func._2._1.tp) + " " + quote(func._2._1) + ", ") + stream.print(remap(func._2._2.tp) + " " + quote(func._2._2) + ", ") + stream.print(remap(func._2._3.tp) + " " + quote(func._2._3)) + stream.println(") {") + emitBlock(func._2._4) + stream.println("return " + quote(getBlockResult(func._2._4)) + ";") + stream.println("}\n") + }) + functionList0.clear + functionList1.clear + functionList2.clear + functionList3.clear + } } trait CGenTupledFunctions extends CGenFunctions with GenericGenUnboxedTupleAccess { val IR: TupledFunctionsExp import IR._ - - /*override def quote(x: Exp[Any]) : String = x match { - case UnboxedTuple(t) => t.map(quote).mkString("((", ",", "))") - case _ => super.quote(x) - }*/ - - override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case Lambda(fun, UnboxedTuple(xs), y) => - val retType = remap(getBlockResult(y).tp) - val retTp = if (cppExplicitFunRet == "true") "function<"+retType+"("+xs.map(s=>remap(s.tp)).mkString(",")+")>" else "auto" - stream.println(retTp+" "+quote(sym)+ - " = [&]("+xs.map(s=>remap(s.tp)+" "+quote(s)).mkString(",")+") {") + stream.println(remap(y.tp)+" "+quote(sym)+"("+xs.map(s=>remap(s.tp)+" "+quote(s)).mkString(",")+") {") emitBlock(y) val z = getBlockResult(y) - if (retType != "void") + if (remap(z.tp) != "void") stream.println("return " + quote(z) + ";") - stream.println("};") - case Apply(fun, UnboxedTuple(args)) => + stream.println("}") + case a@Apply(fun, UnboxedTuple(args)) => emitValDef(sym, quote(fun) + args.map(quote).mkString("(", ",", ")")) case _ => super.emitNode(sym,rhs) } - - /*def unwrapTupleStr(s: String): Array[String] = { - if (s.startsWith("scala.Tuple")) s.slice(s.indexOf("[")+1,s.length-1).filter(c => c != ' ').split(",") - else Array(s) - }*/ - - /*override def remap[A](m: Manifest[A]): String = m.toString match { - case f if f.startsWith("scala.Function") => - val targs = m.typeArguments.dropRight(1) - val res = remap(m.typeArguments.last) - val targsUnboxed = targs.flatMap(t => unwrapTupleStr(remap(t))) - val sep = if (targsUnboxed.length > 0) "," else "" - "scala.Function" + (targsUnboxed.length) + "[" + targsUnboxed.mkString(",") + sep + res + "]" - - case _ => super.remap(m) - }*/ } diff --git a/src/common/GeneratorOps.scala b/src/common/GeneratorOps.scala new file mode 100644 index 00000000..f7ea9652 --- /dev/null +++ b/src/common/GeneratorOps.scala @@ -0,0 +1,297 @@ +package scala.lms +package common + +import java.io.PrintWriter +import scala.lms.common._ +import scala.collection.mutable +import scala.collection.mutable.Set +import scala.reflect.SourceContext +import scala.collection.mutable.Map + +trait MapOps extends Base { + type MapType[K,V] + + def newMapType[K: Manifest, V: Manifest](): Rep[MapType[K, V]] + def lookupOrDefault[K, V: Manifest](x: Rep[MapType[K, V]], key: Rep[K], defaultVal: Rep[V]): Rep[V] + def updateValue[K, V](x: Rep[MapType[K, V]], key: Rep[K], value: Rep[V]): Rep[Unit] +} + +trait GeneratorOps extends Base with Variables with LiftVariables with IfThenElse with Equal with TupleOps with ListOps with MapOps with ObjectOps with StringOps with HashMapOps with SetOps with LiftNumeric with NumericOps with ArrayOps { + def materializeGenerator[T:Manifest,U:Manifest](gen: Generator[U]): Rep[T] + def dematerializeGenerator[T:Manifest,U:Manifest](genCon: Rep[T]): Generator[U] + + def materializeTupleGenerator[T:Manifest,U:Manifest,V:Manifest](gen: TupleGenerator[U,V]): Rep[T] + def dematerializeTupleGenerator[T:Manifest,U:Manifest,V:Manifest](genCon: Rep[T]): TupleGenerator[U,V] + + implicit def generatorToRep[T:Manifest](gen: Generator[T]): Rep[T] = materializeGenerator[T,T](gen) + //TODO - This implicite should be defined for every collection type + //implicit def repToGenerator[T:Manifest,U:Manifest](genCon: Rep[T]): Generator[U] = dematerializeGenerator[T,U](genCon) + + implicit def tupledGeneratorToRep[K:Manifest, V:Manifest](gen: TupleGenerator[K,V]): Rep[(K,V)] = materializeTupleGenerator[(K,V),K,V](gen) + //TODO - This implicite should be defined for every collection type + //implicit def repToGenerator[T:Manifest,U:Manifest](genCon: Rep[T]): Generator[U] = dematerializeGenerator[T,U](genCon) + + abstract class TupleGenerator[K:Manifest, V:Manifest] extends ((Rep[(K,V)] => Rep[Unit]) => Rep[Unit]) with Serializable /*extends Generator[(K,V)]*/ { self => + + /*override*/ def map[K2:Manifest, V2:Manifest](g: Rep[(K,V)] => Rep[(K2,V2)]) = new TupleGenerator[K2,V2] { + def apply(f: Rep[(K2,V2)] => Rep[Unit]) = self.apply { + x:Rep[(K,V)] => f(g(x)) + } + } + + /*override*/ def filter(p: Rep[(K,V)] => Rep[Boolean]) = new TupleGenerator[K,V] { + def apply(f: Rep[(K,V)] => Rep[Unit]) = self.apply { + x:Rep[(K,V)] => if(p(x)) f(x) + } + } + + /*override*/ def ++(that: TupleGenerator[K,V]) = new TupleGenerator[K,V] { + def apply(f: Rep[(K,V)] => Rep[Unit]) = { + self.apply(f) + that.apply(f) + } + } + + /*override*/ def flatMap[K2:Manifest, V2:Manifest](g: Rep[(K,V)] => TupleGenerator[K2,V2]) = new TupleGenerator[K2,V2] { + def apply(f: Rep[(K2,V2)] => Rep[Unit]) = self.apply { x:Rep[(K,V)] => + val tmp : TupleGenerator[K2,V2] = g(x) + tmp(f) + } + } + + /*override*/ def reduce(h:(Rep[(K,V)],Rep[(K,V)])=>Rep[(K,V)], z:Rep[(K,V)]) = new TupleGenerator[K,V] { + def apply(f: Rep[(K,V)] => Rep[Unit]) = { + var best = z; + self.apply { x:Rep[(K,V)] => if (best==z) best=x; else best=h(best,x) } + if (best!=z) f(best) + } + } + + /*override*/ def flatten[K2:Manifest, V2:Manifest] = flatMap[K2,V2] { + x:Rep[(K,V)] => dematerializeTupleGenerator[(K,V),K2,V2](x) + } + + /*override*/ def fold[Y:Manifest](init: Rep[Y], g: Rep[(K,V)] => (Rep[Y] => Rep[Y])): Rep[Y] = { + var res = init + self.apply { + x:Rep[(K,V)] => res = g(x)(res) + } + res + } + + /*override*/ def foldLong(init: Rep[Long], g: Rep[(K,V)] => (Rep[Long] => Rep[Long])): Rep[Long] = { + var res = init + self.apply { + x:Rep[(K,V)] => res = g(x)(res) + } + res + } + + /*override*/ def foreach(g: Rep[(K,V)] => Rep[Unit]) = self.apply { + x:Rep[(K,V)] => g(x) + } + + /*override*/ def toList: Rep[List[(K,V)]] = { + var resList = List[(K,V)]() + self.apply { + x:Rep[(K,V)] => resList = x :: resList + } + resList + } + + def groupByAggregate[K2:Manifest, V2:Manifest](init: Rep[V2], group: Rep[(K, V)] => Rep[K2], + fn: Rep[(K, V)] => (Rep[V2] => Rep[V2])): Rep[MapType[K2, V2]] = { + val grps = newMapType[K2,V2]() + self.apply { + x:Rep[(K,V)] => { + val key: Rep[K2] = group(x) + val value = fn(x)(lookupOrDefault(grps,key,init)) + updateValue(grps,key,value) + } + } + grps + } + + def groupByMultipleAggregates[K2:Manifest, V2:Manifest](newMapFun: () => Rep[scala.collection.mutable.HashMap[K2, Array[V2]]], numAggs: Rep[Long], group: Rep[(K, V)] => Rep[K2], fn: (Rep[V], Rep[V2]) => Rep[V2]*): Rep[scala.collection.mutable.HashMap[K2, Array[V2]]] = { + val grps = newMapFun()//HashMap[K2,Array[V2]]() + self.apply { + x:Rep[(K,V)] => { + val key: Rep[K2] = group(x) + val aggs = grps.getOrElseUpdate(key, NewArray[V2](numAggs))//fn.length))// lookupOrDefault[K2,Array[V2]](grps,key,init)) + fn.foldLeft(0L) { (cnt,aggfn) => { + val value = aggfn(x._2,aggs(cnt)) + aggs(cnt) = value + cnt+1 + } } + unit() + } + } + grps + } + + def mkString(delimiter: Rep[String] = unit("")): Rep[String] = { + var res = string_new(unit("")) + self.apply { + x:Rep[(K,V)] => res = res + infix_ToString(x)//.ToString + if (delimiter != unit("")) res = res + delimiter + } + res + } + + + /*def slice[K2: Manifest](kp: Rep[K2], idx: Rep[List[Int]]): TupleGenerator[K, V] = { + self.filter{ + kv:Rep[(K,V)] => { + val k = kv._1.asInstanceOf[Rep[(K2,_)]] + kp == tuple2_get1(k) + } + } + }*/ + } + + // Generator[T] === (T => Unit) => Unit + abstract class Generator[T:Manifest] extends ((Rep[T] => Rep[Unit]) => Rep[Unit]) with Serializable { self => + + //Rep[T => U] != Rep[T] => Rep[U] + def map[U:Manifest](g: Rep[T] => Rep[U]) = new Generator[U] { + def apply(f: Rep[U] => Rep[Unit]) = self.apply { + x:Rep[T] => f(g(x)) + } + } + + def filter(p: Rep[T] => Rep[Boolean]) = new Generator[T] { + def apply(f: Rep[T] => Rep[Unit]) = self.apply { + x:Rep[T] => if(p(x)) f(x) + } + } + + def ++(that: Generator[T]) = new Generator[T] { + def apply(f: Rep[T] => Rep[Unit]) = { + self.apply(f) + that.apply(f) + } + } + + def flatMap[U:Manifest](g: Rep[T] => Generator[U]) = new Generator[U]{ + def apply(f: Rep[U] => Rep[Unit]) = self.apply{ x:Rep[T] => + val tmp : Generator[U] = g(x) + tmp(f) + } + } + + def reduce(h:(Rep[T],Rep[T])=>Rep[T], z:Rep[T]) = new Generator[T] { + def apply(f: Rep[T] => Rep[Unit]) = { + var best = z; + self.apply { x:Rep[T] => if (best==z) best=x; else best=h(best,x) } + if (best!=z) f(best) + } + } + + def flatten[U:Manifest] = flatMap[U] { + x:Rep[T] => dematerializeGenerator[T,U](x) + } + + def fold[Y:Manifest](init: Rep[Y], g: Rep[T] => (Rep[Y] => Rep[Y])): Rep[Y] = { + var res = init + self.apply { + x:Rep[T] => res = g(x)(res) + } + res + } + + def sum(implicit num: Numeric[T]): Rep[T] = { + var res = unit(num.zero) + self.apply { + x:Rep[T] => res = numeric_plus(x, readVar(res)) + } + readVar(res) + } + + + def foldLong(init: Rep[Long], g: Rep[T] => (Rep[Long] => Rep[Long])): Rep[Long] = { + var res = init + self.apply { + x:Rep[T] => res = g(x)(res) + } + res + } + + def foreach(g: Rep[T] => Rep[Unit]) = self.apply { + x:Rep[T] => g(x) + } + + def toList: Rep[List[T]] = { + var resList = List[T]() + self.apply { + x:Rep[T] => resList = x :: resList + } + resList + } + + } + + case class EmptyGen[T:Manifest]() extends Generator[T]{ + def apply(f: Rep[T] => Rep[Unit]) = {} + } + + def emptyGen[A:Manifest](): Generator[A] = EmptyGen[A] + + def elGen[A:Manifest](a: Rep[A]): Generator[A] = new Generator[A]{ + def apply(f: Rep[A] => Rep[Unit]) = { + f(a) + } + } + + def cond[A:Manifest](cond: Rep[Boolean], a: Generator[A], b: Generator[A]) = new Generator[A]{ + def apply(f: Rep[A] => Rep[Unit]) = { + if(cond) a(f) else b(f) + } + } +} + +trait GeneratorOpsExp extends GeneratorOps with EffectExp with VariablesExp with IfThenElseExp with EqualExp with TupleOpsExp with ListOpsExp with ObjectOpsExp with StringOpsExp with HashMapOpsExp with ListBufferExp with HashMultiMapOpsExp with SetOpsExp with ArrayOpsExp { + + case class GeneratorContainer[T: Manifest,U:Manifest](gen: Generator[U]) extends Def[T] + case class TupleGeneratorContainer[T: Manifest,U:Manifest,V:Manifest](gen: TupleGenerator[U,V]) extends Def[T] + + def materializeGenerator[T:Manifest,U:Manifest](gen: Generator[U]): Rep[T] = GeneratorContainer[T,U](gen) + def dematerializeGenerator[T:Manifest,U:Manifest](genCon: Rep[T]): Generator[U] = { + findDefinition(genCon.asInstanceOf[Sym[T]]).get.rhs match { + case Reflect(ReadVar(x), _, _) => x.asInstanceOf[GeneratorContainer[T,U]].gen + case x => x.asInstanceOf[GeneratorContainer[T,U]].gen + } + } + def materializeTupleGenerator[T:Manifest,U:Manifest,V:Manifest](gen: TupleGenerator[U,V]): Rep[T] = TupleGeneratorContainer[T,U,V](gen) + def dematerializeTupleGenerator[T:Manifest,U:Manifest,V:Manifest](genCon: Rep[T]): TupleGenerator[U,V] = { + findDefinition(genCon.asInstanceOf[Sym[T]]).get.rhs match { + case Reflect(ReadVar(x), _, _) => x.asInstanceOf[TupleGeneratorContainer[T,U,V]].gen + case x => x.asInstanceOf[TupleGeneratorContainer[T,U,V]].gen + } + } +} + +trait ScalaGenGeneratorOps extends ScalaGenVariables + with ScalaGenIfThenElse with ScalaGenEqual with ScalaGenListOps with ScalaGenTupleOps { + val IR: GeneratorOpsExp + import IR._ + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + // currently, we shoud explicitly call toList method on a generator, in order to convert it again to list + + case TupleGeneratorContainer(gen) => val genList = gen.toList; emitNode(sym, Def.unapply(genList).get) + case GeneratorContainer(gen) => + case _ => super.emitNode(sym, rhs) + } + +} + +trait CGenGeneratorOps extends CGenVariables + with CGenIfThenElse with CLikeGenEqual with CLikeGenListOps { + val IR: GeneratorOpsExp + import IR._ + + /*override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case _ => super.emitNode(sym, rhs) + }*/ + +} diff --git a/src/common/IOOps.scala b/src/common/IOOps.scala index 62f9e076..74f52201 100644 --- a/src/common/IOOps.scala +++ b/src/common/IOOps.scala @@ -1,8 +1,8 @@ package scala.lms package common -import java.io.{File, FileReader, FileWriter, BufferedReader, BufferedWriter, PrintWriter} -import scala.lms.internal.{GenerationFailedException} +import java.io.{File, FileReader, FileWriter, BufferedReader, BufferedWriter, PrintWriter, FileOutputStream, ObjectOutputStream, FileInputStream, ObjectInputStream} +import scala.lms.internal.{GenerationFailedException, GenericNestedCodegen} import util.OverloadHack import scala.reflect.SourceContext @@ -19,11 +19,13 @@ trait IOOps extends Variables with OverloadHack { def infix_getCanonicalFile(f: Rep[File])(implicit pos: SourceContext) = file_getcanonicalfile(f) def infix_getPath(f: Rep[File])(implicit pos: SourceContext) = file_getpath(f) def infix_listFiles(f: Rep[File])(implicit pos: SourceContext) = file_listfiles(f) + def infix_close(f: Rep[File])(implicit pos: SourceContext, o: Overloaded1) = file_close(f) // Only for the C code gen def obj_file_apply(dir: Rep[String])(implicit pos: SourceContext): Rep[File] def file_getcanonicalfile(f: Rep[File])(implicit pos: SourceContext): Rep[File] def file_getpath(f: Rep[File])(implicit pos: SourceContext): Rep[String] def file_listfiles(f: Rep[File])(implicit pos: SourceContext): Rep[Array[File]] + def file_close(f: Rep[File])(implicit pos: SourceContext): Rep[Unit] // Only for the C code gen /** * BufferedReader @@ -32,7 +34,7 @@ trait IOOps extends Variables with OverloadHack { def apply(f: Rep[FileReader])(implicit pos: SourceContext) = obj_br_apply(f) } def infix_readLine(b: Rep[BufferedReader])(implicit pos: SourceContext) = br_readline(b) - def infix_close(b: Rep[BufferedReader])(implicit pos: SourceContext) = br_close(b) + def infix_close(b: Rep[BufferedReader])(implicit pos: SourceContext, o: Overloaded2) = br_close(b) def obj_br_apply(f: Rep[FileReader])(implicit pos: SourceContext): Rep[BufferedReader] def br_readline(b: Rep[BufferedReader])(implicit pos: SourceContext): Rep[String] @@ -67,6 +69,58 @@ trait IOOps extends Variables with OverloadHack { def apply(s: Rep[String])(implicit pos: SourceContext) = obj_fw_apply(s) } def obj_fw_apply(s: Rep[String])(implicit pos: SourceContext): Rep[FileWriter] + + /** + * ObjectOutputStream + */ + object FileInputStream { + def apply(s: Rep[String])(implicit pos: SourceContext) = obj_fis_apply(s) + } + class FileInputStreamOps(x: Rep[FileInputStream]) { + def available()(implicit pos: SourceContext) = obj_fis_available(x) + } + implicit def fisToFisOps(x: Rep[FileInputStream]) = new FileInputStreamOps(x) + def obj_fis_apply(s: Rep[String]): Rep[FileInputStream] + def obj_fis_available(s: Rep[FileInputStream]): Rep[Int] + + object ObjectInputStream { + def apply(s: Rep[FileInputStream])(implicit pos: SourceContext) = obj_ois_apply(s) + } + class ObjectInputStreamOps(x: Rep[ObjectInputStream]) { + def readObject(dynamicType: String = null)(implicit pos: SourceContext) = obj_ois_readObject(x, dynamicType) + def close()(implicit pos: SourceContext) = obj_ois_close(x) + } + implicit def oisTooisOps(x: Rep[ObjectInputStream]) = new ObjectInputStreamOps(x) + def obj_ois_apply(s: Rep[FileInputStream]): Rep[ObjectInputStream] + def obj_ois_close(s: Rep[ObjectInputStream]): Rep[Unit] + def obj_ois_readObject(x: Rep[ObjectInputStream], dynamicType: String = null): Rep[Object] + + object ObjectOutputStream { + def apply(s: Rep[FileOutputStream])(implicit pos: SourceContext) = obj_oos_apply(s, unit(false)) + def apply(s: Rep[FileOutputStream], x: Rep[Boolean])(implicit pos: SourceContext) = obj_oos_apply(s, x) + } + class ObjectOutputStreamOps(x: Rep[ObjectOutputStream]) { + def writeObject(elem: Rep[Any])(implicit pos: SourceContext) = obj_oos_writeObject(x,elem) + def close()(implicit pos: SourceContext) = obj_oos_close(x) + } + implicit def oosToOoosOps(x: Rep[ObjectOutputStream]) = new ObjectOutputStreamOps(x) + def obj_oos_apply(s: Rep[FileOutputStream], x: Rep[Boolean])(implicit pos: SourceContext): Rep[ObjectOutputStream] + def obj_oos_writeObject(s: Rep[ObjectOutputStream], elem: Rep[Any])(implicit pos: SourceContext): Rep[Unit] + def obj_oos_close(s: Rep[ObjectOutputStream])(implicit pos: SourceContext): Rep[Unit] + + /** + * FileWriter + */ + object FileOutputStream { + def apply(s: Rep[File])(implicit pos: SourceContext) = obj_fos_apply(s) + } + def obj_fos_apply(s: Rep[File])(implicit pos: SourceContext): Rep[FileOutputStream] + + object FileLineCount { + def apply(s: Rep[String])(implicit pos: SourceContext) = file_line_count(s) + } + def file_line_count(s: Rep[String])(implicit pos: SourceContext): Rep[Int] + } trait IOOpsExp extends IOOps with DSLOpsExp { @@ -74,42 +128,66 @@ trait IOOpsExp extends IOOps with DSLOpsExp { case class FileGetCanonicalFile(f: Exp[File]) extends Def[File] case class FileGetPath(f: Exp[File]) extends Def[String] case class FileListFiles(f: Exp[File]) extends Def[Array[File]] + case class FileClose(f: Exp[File]) extends Def[Unit] // Only for the C code gen case class ObjBrApply(f: Exp[FileReader]) extends Def[BufferedReader] case class ObjBwApply(f: Exp[FileWriter]) extends Def[BufferedWriter] case class ObjFrApply(s: Exp[String]) extends Def[FileReader] case class ObjFwApply(s: Exp[String]) extends Def[FileWriter] + case class ObjOosApply(s: Exp[FileOutputStream], x: Rep[Boolean]) extends Def[ObjectOutputStream] + case class ObjOosWriteObject(s: Exp[ObjectOutputStream], elem: Exp[Any]) extends Def[Unit] + case class ObjOosClose(s: Exp[ObjectOutputStream]) extends Def[Unit] + case class ObjFosApply(s: Exp[File]) extends Def[FileOutputStream] + case class ObjFisApply(s: Exp[String]) extends Def[FileInputStream] + case class ObjOisApply(s: Exp[FileInputStream]) extends Def[ObjectInputStream] + case class ObjOisClose(s: Exp[ObjectInputStream]) extends Def[Unit] + case class ObjOisAvailable(s: Exp[FileInputStream]) extends Def[Int] + case class ObjOisReadObject(s: Exp[ObjectInputStream], dynamicType: String = null) extends Def[Object] case class BwWrite(b: Exp[BufferedWriter], s: Rep[String]) extends Def[Unit] case class BwClose(b: Exp[BufferedWriter]) extends Def[Unit] case class BrReadline(b: Exp[BufferedReader]) extends Def[String] case class BrClose(b: Exp[BufferedReader]) extends Def[Unit] + case class CountFileLines(b: Exp[String]) extends Def[Int] { + val f = fresh[java.io.File] // used in c code gen + } def obj_file_apply(dir: Exp[String])(implicit pos: SourceContext): Exp[File] = reflectEffect(ObjFileApply(dir)) def file_getcanonicalfile(f: Exp[File])(implicit pos: SourceContext) = FileGetCanonicalFile(f) def file_getpath(f: Exp[File])(implicit pos: SourceContext) = FileGetPath(f) def file_listfiles(f: Exp[File])(implicit pos: SourceContext) = FileListFiles(f) + def file_close(f: Exp[File])(implicit pos: SourceContext) = FileClose(f) // Only for the C code gen def obj_br_apply(f: Exp[FileReader])(implicit pos: SourceContext): Exp[BufferedReader] = reflectEffect(ObjBrApply(f)) def obj_bw_apply(f: Exp[FileWriter])(implicit pos: SourceContext): Exp[BufferedWriter] = reflectEffect(ObjBwApply(f)) def obj_fr_apply(s: Exp[String])(implicit pos: SourceContext): Exp[FileReader] = reflectEffect(ObjFrApply(s)) def obj_fw_apply(s: Exp[String])(implicit pos: SourceContext): Exp[FileWriter] = reflectEffect(ObjFwApply(s)) + def obj_oos_apply(s: Exp[FileOutputStream], x: Rep[Boolean])(implicit pos: SourceContext): Exp[ObjectOutputStream] = reflectEffect(ObjOosApply(s,x)) + def obj_oos_writeObject(s: Exp[ObjectOutputStream], elem: Exp[Any])(implicit pos: SourceContext): Exp[Unit] = reflectEffect(ObjOosWriteObject(s, elem)) + def obj_oos_close(s: Exp[ObjectOutputStream])(implicit pos: SourceContext): Exp[Unit] = reflectEffect(ObjOosClose(s)) + def obj_fos_apply(s: Exp[File])(implicit pos: SourceContext): Exp[FileOutputStream] = reflectEffect(ObjFosApply(s)) + def obj_fis_apply(s: Rep[String]) = reflectEffect(ObjFisApply(s)) + def obj_ois_apply(s: Rep[FileInputStream]) = reflectEffect(ObjOisApply(s)) + def obj_ois_close(s: Rep[ObjectInputStream]) = reflectEffect(ObjOisClose(s)) + def obj_fis_available(s: Rep[FileInputStream]) = reflectEffect(ObjOisAvailable(s)) + def obj_ois_readObject(x: Rep[ObjectInputStream], dynamicType: String = null) = reflectEffect(ObjOisReadObject(x, dynamicType)) def bw_write(b: Exp[BufferedWriter], s: Exp[String])(implicit pos: SourceContext) = reflectEffect(BwWrite(b,s)) def bw_close(b: Exp[BufferedWriter])(implicit pos: SourceContext) = reflectEffect(BwClose(b)) def br_readline(b: Exp[BufferedReader])(implicit pos: SourceContext) : Exp[String] = reflectEffect(BrReadline(b)) def br_close(b: Exp[BufferedReader])(implicit pos: SourceContext) : Exp[Unit] = reflectEffect(BrClose(b)) + def file_line_count(s: Rep[String])(implicit pos: SourceContext) = reflectEffect(CountFileLines(s)) override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = ({ e match { - case Reflect(ObjFrApply(s), u, es) => reflectMirrored(Reflect(ObjFrApply(f(s)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(ObjBrApply(x), u, es) => reflectMirrored(Reflect(ObjBrApply(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(ObjFwApply(s), u, es) => reflectMirrored(Reflect(ObjFwApply(f(s)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(ObjBwApply(x), u, es) => reflectMirrored(Reflect(ObjBwApply(f(x)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(BrReadline(b), u, es) => reflectMirrored(Reflect(BrReadline(f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(BwWrite(b,s), u, es) => reflectMirrored(Reflect(BwWrite(f(b),f(s)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(BrClose(b), u, es) => reflectMirrored(Reflect(BrClose(f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(BwClose(b), u, es) => reflectMirrored(Reflect(BwClose(f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(ObjFrApply(s), u, es) => obj_fr_apply(f(s)) + case Reflect(ObjBrApply(x), u, es) => obj_br_apply(f(x)) + case Reflect(ObjFwApply(s), u, es) => obj_fw_apply(f(s)) + case Reflect(ObjBwApply(x), u, es) => obj_bw_apply(f(x)) + case Reflect(BrReadline(b), u, es) => br_readline(f(b)) + case Reflect(BwWrite(b,s), u, es) => bw_write(f(b),f(s)) + case Reflect(BrClose(b), u, es) => br_close(f(b)) + case Reflect(BwClose(b), u, es) => bw_close(f(b)) case _ => super.mirror(e,f) } }).asInstanceOf[Exp[A]] @@ -124,32 +202,68 @@ trait ScalaGenIOOps extends ScalaGenBase { case FileGetCanonicalFile(f) => emitValDef(sym, src"$f.getCanonicalFile()") case FileGetPath(f) => emitValDef(sym, src"$f.getPath()") case FileListFiles(f) => emitValDef(sym, src"$f.listFiles()") + case FileClose(f) => throw new GenerationFailedException("File.close is not defined for Scala Generation, only for C! Maybe you meant to close the BufferedStreams instead.") case ObjBrApply(f) => emitValDef(sym, src"new java.io.BufferedReader($f)") case ObjBwApply(f) => emitValDef(sym, src"new java.io.BufferedWriter($f)") case ObjFrApply(s) => emitValDef(sym, src"new java.io.FileReader($s)") case ObjFwApply(s) => emitValDef(sym, src"new java.io.FileWriter($s)") + case ObjOosApply(s,x) => + if (x == Const(true)) { + emitValDef(sym, src"new java.io.ObjectOutputStream($s){") + gen"""override protected def writeStreamHeader() { + |reset(); + |} + |}""" + } else emitValDef(sym, src"new java.io.ObjectOutputStream($s)") + case ObjOosWriteObject(s, elem) => gen"$s.writeObject($elem)" + case ObjOosClose(s) => gen"$s.close" + case ObjFosApply(s) => emitValDef(sym, src"new java.io.FileOutputStream($s,true)") + case ObjFisApply(s) => emitValDef(sym, src"new java.io.FileInputStream($s)") + case ObjOisApply(s) => emitValDef(sym, src"new java.io.ObjectInputStream($s)") + case ObjOisClose(s) => emitValDef(sym, src"$s.close") + case ObjOisAvailable(s) => emitValDef(sym, src"$s.available") + case ObjOisReadObject(s, dtype) => { + if (dtype == null) emitValDef(sym, src"$s.readObject") + else emitValDef(sym, src"$s.readObject.asInstanceOf[$dtype]") + } case BwWrite(b,s) => emitValDef(sym, src"$b.write($s)") case BwClose(b) => emitValDef(sym, src"$b.close()") case BrReadline(b) => emitValDef(sym, src"$b.readLine()") case BrClose(b) => emitValDef(sym, src"$b.close()") + case CountFileLines(b) => emitValDef(sym, "{import scala.sys.process._; Integer.parseInt(((\"wc -l \" +" + quote(b) + ") #| \"awk {print($1)}\" !!).replaceAll(\"\\\\s+$\", \"\"))}") case _ => super.emitNode(sym, rhs) } } -trait CLikeGenIOOps extends CLikeGenBase { +trait CLikeGenIOOps extends CLikeGenBase with GenericNestedCodegen { val IR: IOOpsExp import IR._ + override def remap[A](m: Manifest[A]) = { + m match { + case s if s == manifest[File] => "FILE*" + case _ => super.remap(m) + } + } + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case ObjFileApply(dir) => emitValDef(sym, "fopen(" + quote(dir) + ", \"rw\")") + case ObjOosApply(s,x) => quote(s) + case ObjFosApply(s) => quote(s) + case FileClose(s) => stream.println("fclose(" + quote(s) + ")") case ObjBrApply(f) => throw new GenerationFailedException("CLikeGenIOOps: Java IO operations are not supported") case ObjFrApply(s) => throw new GenerationFailedException("CLikeGenIOOps: Java IO operations are not supported") case BrReadline(b) => throw new GenerationFailedException("CLikeGenIOOps: Java IO operations are not supported") case BrClose(b) => throw new GenerationFailedException("CLikeGenIOOps: Java IO operations are not supported") + case c@CountFileLines(b) => { + emitValDef(c.f, "popen(\"wc -l " + quote(b).replace("\"","") + "\",\"r\");") + stream.println("int " + quote(sym) + " = 0;") + stream.println("fscanf(" + quote(c.f) + ",\"%d\", &" + quote(sym) + ");") + stream.println("pclose(" + quote(c.f) + ");") + } case _ => super.emitNode(sym, rhs) } } trait CudaGenIOOps extends CudaGenBase with CLikeGenIOOps trait OpenCLGenIOOps extends OpenCLGenBase with CLikeGenIOOps -trait CGenIOOps extends CGenBase with CLikeGenIOOps - - +trait CGenIOOps extends CGenBase with CLikeGenIOOps diff --git a/src/common/IfThenElse.scala b/src/common/IfThenElse.scala index c691c9cf..7673201c 100644 --- a/src/common/IfThenElse.scala +++ b/src/common/IfThenElse.scala @@ -40,7 +40,7 @@ trait IfThenElseExp extends IfThenElse with EffectExp { val a = reifyEffectsHere(thenp) val b = reifyEffectsHere(elsep) - ifThenElse(cond,a,b) + ifThenElse[T](cond,a,b) } def ifThenElse[T:Manifest](cond: Rep[Boolean], thenp: Block[T], elsep: Block[T])(implicit pos: SourceContext) = { @@ -54,7 +54,7 @@ trait IfThenElseExp extends IfThenElse with EffectExp { // (see TestMutation, for now sticking to old behavior) ////reflectEffect(IfThenElse(cond,thenp,elsep), ae orElse be) - reflectEffectInternal(IfThenElse(cond,thenp,elsep), ae orElse be) + reflectEffectInternal(IfThenElse[T](cond,thenp,elsep), ae orElse be) } override def mirrorDef[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Def[A] = e match { @@ -67,7 +67,7 @@ trait IfThenElseExp extends IfThenElse with EffectExp { if (f.hasContext) __ifThenElse(f(c),f.reflectBlock(a),f.reflectBlock(b)) else - reflectMirrored(Reflect(IfThenElse(f(c),f(a),f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + reflectMirrored(Reflect(IfThenElse(f(c),f(a),f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) case IfThenElse(c,a,b) => if (f.hasContext) __ifThenElse(f(c),f.reflectBlock(a),f.reflectBlock(b)) @@ -200,6 +200,12 @@ trait IfThenElseExpOpt extends IfThenElseExp { this: BooleanOpsExp with EqualExp } } + + + + + + trait BaseGenIfThenElse extends GenericNestedCodegen { val IR: IfThenElseExp import IR._ @@ -229,13 +235,18 @@ trait ScalaGenIfThenElse extends ScalaGenEffect with BaseGenIfThenElse { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case IfThenElse(c,a,b) => - stream.println("val " + quote(sym) + " = if (" + quote(c) + ") {") - emitBlock(a) - stream.println(quote(getBlockResult(a))) - stream.println("} else {") - emitBlock(b) - stream.println(quote(getBlockResult(b))) - stream.println("}") + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + stream.println("if (" + quote(c) + ") {") + emitBlock(a) + emitBlockResult(a) + stream.println("} else {") + emitBlock(b) + emitBlockResult(b) + stream.print("}") + } + emitValDef(sym, strWriter.toString) case _ => super.emitNode(sym, rhs) } } @@ -245,7 +256,7 @@ trait ScalaGenIfThenElseFat extends ScalaGenIfThenElse with ScalaGenFat with Bas override def emitFatNode(symList: List[Sym[Any]], rhs: FatDef) = rhs match { case SimpleFatIfThenElse(c,as,bs) => - def quoteList[T](xs: List[Exp[T]]) = if (xs.length > 1) xs.map(quote).mkString("(",",",")") else xs.map(quote).mkString(",") + def quoteList[T](xs: List[Exp[T]]) = if (xs.length > 1) xs.map(x => quote(x, true)).mkString("(",",",")") else xs.map(x => quote(x,true)).mkString(",") if (symList.length > 1) stream.println("// TODO: use vars instead of tuples to return multiple values") stream.println("val " + quoteList(symList) + " = if (" + quote(c) + ") {") emitFatBlock(as) @@ -375,37 +386,23 @@ trait CGenIfThenElse extends CGenEffect with BaseGenIfThenElse { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { rhs match { case IfThenElse(c,a,b) => - //TODO: using if-else does not work - remap(sym.tp) match { - case "void" => + //TODO: using if-else does noIIIt work + isVoidType(sym.tp) match { + case true => stream.println("if (" + quote(c) + ") {") emitBlock(a) stream.println("} else {") emitBlock(b) stream.println("}") - case _ => - if (cppIfElseAutoRet == "true") { - val ten = quote(sym) + "True" - val fen = quote(sym) + "False" - def emitCondFun[T: Manifest](fname: String, block: Block[T]) { - stream.println("auto " + fname + " = [&]() {"); - emitBlock(block) - stream.println("return " + quote(getBlockResult(block)) + ";") - stream.println("};") - } - emitCondFun(ten, a) - emitCondFun(fen, b) - stream.println("auto " + quote(sym) + " = " + quote(c) + " ? " + ten + "() : " + fen + "();") - } else { - stream.println("%s %s;".format(remap(sym.tp),quote(sym))) - stream.println("if (" + quote(c) + ") {") - emitBlock(a) - stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(a)))) - stream.println("} else {") - emitBlock(b) - stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(b)))) - stream.println("}") - } + case false => + stream.println("%s %s;".format(remap(getBlockResult(a).tp),quote(sym))) + stream.println("if (" + quote(c) + ") {") + emitBlock(a) + stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(a)))) + stream.println("} else {") + emitBlock(b) + stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(b)))) + stream.println("}") } /* val booll = remap(sym.tp).equals("void") @@ -436,7 +433,17 @@ trait CGenIfThenElseFat extends CGenIfThenElse with CGenFat with BaseGenIfThenEl import IR._ override def emitFatNode(symList: List[Sym[Any]], rhs: FatDef) = rhs match { - case SimpleFatIfThenElse(c,a,b) => sys.error("TODO: implement fat if C codegen") + case SimpleFatIfThenElse(c,as,bs) => + def quoteList[T](xs: List[Exp[T]]) = if (xs.length > 1) xs.map(x => quote(x, true)).mkString("(",",",")").replace("()","") else xs.map(x => quote(x,true)).mkString(",").replace("()","") + if (symList.length > 1) stream.println("// TODO: use vars instead of tuples to return multiple values") + stream.println("if (" + quote(c) + ") {") + emitFatBlock(as) + stream.println(quoteList(as.map(getBlockResult))) + stream.println("} else {") + emitFatBlock(bs) + stream.println(quoteList(bs.map(getBlockResult))) + stream.println("}") + case _ => super.emitFatNode(symList, rhs) } } diff --git a/src/common/IterableOps.scala b/src/common/IterableOps.scala index 57d44740..a41ec626 100644 --- a/src/common/IterableOps.scala +++ b/src/common/IterableOps.scala @@ -39,8 +39,8 @@ trait IterableOpsExp extends IterableOps with EffectExp with VariablesExp { override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = { (e match { case e@IterableToArray(x) => iterable_toarray(f(x))(e.m,pos) - case Reflect(e@IterableForeach(x,y,b), u, es) => reflectMirrored(Reflect(IterableForeach(f(x),f(y).asInstanceOf[Sym[_]],f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(e@IterableToArray(x), u, es) => reflectMirrored(Reflect(IterableToArray(f(x))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(e@IterableForeach(x,y,b), u, es) => reflectMirrored(Reflect(IterableForeach(f(x),f(y).asInstanceOf[Sym[_]],f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(e@IterableToArray(x), u, es) => reflectMirrored(Reflect(IterableToArray(f(x))(e.m), mapOver(f,u), f(es)))(mtype(manifest[A])) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] // why?? } @@ -73,11 +73,11 @@ trait ScalaGenIterableOps extends BaseGenIterableOps with ScalaGenBase { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case IterableForeach(a,x,block) => - gen"""val $sym=$a.foreach{ - |$x => - |${nestedBlock(block)} - |$block - |}""" + gen"""$a.foreach{ + |$x => + |${nestedBlock(block)} + |$block + |}""" case IterableToArray(a) => emitValDef(sym, src"$a.toArray") case _ => super.emitNode(sym, rhs) } diff --git a/src/common/ListOps.scala b/src/common/ListOps.scala index d251d1bd..51508f7c 100644 --- a/src/common/ListOps.scala +++ b/src/common/ListOps.scala @@ -10,6 +10,10 @@ trait ListOps extends Variables { object List { def apply[A:Manifest](xs: Rep[A]*)(implicit pos: SourceContext) = list_new(xs) } + + object NewList { + def apply[A:Manifest](xs: Rep[A]*)(implicit pos: SourceContext) = list_new(xs) + } implicit def varToListOps[T:Manifest](x: Var[List[T]]) = new ListOpsCls(readVar(x)) // FIXME: dep on var is not nice implicit def repToListOps[T:Manifest](a: Rep[List[T]]) = new ListOpsCls(a) @@ -17,23 +21,25 @@ trait ListOps extends Variables { class ListOpsCls[A:Manifest](l: Rep[List[A]]) { def map[B:Manifest](f: Rep[A] => Rep[B]) = list_map(l,f) + def foreach(f: Rep[A] => Rep[Unit]) = list_foreach(l,f) def flatMap[B : Manifest](f: Rep[A] => Rep[List[B]]) = list_flatMap(l,f) def filter(f: Rep[A] => Rep[Boolean]) = list_filter(l, f) def sortBy[B:Manifest:Ordering](f: Rep[A] => Rep[B]) = list_sortby(l,f) def ::(e: Rep[A]) = list_prepend(l,e) def ++ (l2: Rep[List[A]]) = list_concat(l, l2) def mkString = list_mkString(l) - def mkString(s:Rep[String]) = list_mkString2(l,s) def head = list_head(l) def tail = list_tail(l) def isEmpty = list_isEmpty(l) def toArray = list_toarray(l) def toSeq = list_toseq(l) + def contains(e: Rep[A]) = list_contains(l,e) } def list_new[A:Manifest](xs: Seq[Rep[A]])(implicit pos: SourceContext): Rep[List[A]] def list_fromseq[A:Manifest](xs: Rep[Seq[A]])(implicit pos: SourceContext): Rep[List[A]] def list_map[A:Manifest,B:Manifest](l: Rep[List[A]], f: Rep[A] => Rep[B])(implicit pos: SourceContext): Rep[List[B]] + def list_foreach[A:Manifest](l: Rep[List[A]], f: Rep[A] => Rep[Unit])(implicit pos: SourceContext): Rep[Unit] def list_flatMap[A : Manifest, B : Manifest](xs: Rep[List[A]], f: Rep[A] => Rep[List[B]])(implicit pos: SourceContext): Rep[List[B]] def list_filter[A : Manifest](l: Rep[List[A]], f: Rep[A] => Rep[Boolean])(implicit pos: SourceContext): Rep[List[A]] def list_sortby[A:Manifest,B:Manifest:Ordering](l: Rep[List[A]], f: Rep[A] => Rep[B])(implicit pos: SourceContext): Rep[List[A]] @@ -43,16 +49,19 @@ trait ListOps extends Variables { def list_concat[A:Manifest](xs: Rep[List[A]], ys: Rep[List[A]])(implicit pos: SourceContext): Rep[List[A]] def list_cons[A:Manifest](x: Rep[A], xs: Rep[List[A]])(implicit pos: SourceContext): Rep[List[A]] // FIXME remove? def list_mkString[A : Manifest](xs: Rep[List[A]])(implicit pos: SourceContext): Rep[String] - def list_mkString2[A : Manifest](xs: Rep[List[A]], sep:Rep[String])(implicit pos: SourceContext): Rep[String] def list_head[A:Manifest](xs: Rep[List[A]])(implicit pos: SourceContext): Rep[A] def list_tail[A:Manifest](xs: Rep[List[A]])(implicit pos: SourceContext): Rep[List[A]] def list_isEmpty[A:Manifest](xs: Rep[List[A]])(implicit pos: SourceContext): Rep[Boolean] + def list_contains[A:Manifest](xs: Rep[List[A]], e: Rep[A])(implicit pos: SourceContext): Rep[Boolean] } trait ListOpsExp extends ListOps with EffectExp with VariablesExp { - case class ListNew[A:Manifest](xs: Seq[Rep[A]]) extends Def[List[A]] + case class ListNew[A:Manifest](xs: Seq[Rep[A]]) extends Def[List[A]] { + val m = manifest[A] + } case class ListFromSeq[A:Manifest](xs: Rep[Seq[A]]) extends Def[List[A]] case class ListMap[A:Manifest,B:Manifest](l: Exp[List[A]], x: Sym[A], block: Block[B]) extends Def[List[B]] + case class ListForeach[A:Manifest](l: Exp[List[A]], x: Sym[A], block: Block[Unit]) extends Def[Unit] case class ListFlatMap[A:Manifest, B:Manifest](l: Exp[List[A]], x: Sym[A], block: Block[List[B]]) extends Def[List[B]] case class ListFilter[A : Manifest](l: Exp[List[A]], x: Sym[A], block: Block[Boolean]) extends Def[List[A]] case class ListSortBy[A:Manifest,B:Manifest:Ordering](l: Exp[List[A]], x: Sym[A], block: Block[B]) extends Def[List[A]] @@ -62,10 +71,10 @@ trait ListOpsExp extends ListOps with EffectExp with VariablesExp { case class ListConcat[A:Manifest](xs: Rep[List[A]], ys: Rep[List[A]]) extends Def[List[A]] case class ListCons[A:Manifest](x: Rep[A], xs: Rep[List[A]]) extends Def[List[A]] case class ListMkString[A:Manifest](l: Exp[List[A]]) extends Def[String] - case class ListMkString2[A:Manifest](l: Exp[List[A]], s: Exp[String]) extends Def[String] case class ListHead[A:Manifest](xs: Rep[List[A]]) extends Def[A] case class ListTail[A:Manifest](xs: Rep[List[A]]) extends Def[List[A]] case class ListIsEmpty[A:Manifest](xs: Rep[List[A]]) extends Def[Boolean] + case class ListContains[A:Manifest](xs: Rep[List[A]], e: Rep[A]) extends Def[Boolean] def list_new[A:Manifest](xs: Seq[Rep[A]])(implicit pos: SourceContext) = ListNew(xs) def list_fromseq[A:Manifest](xs: Rep[Seq[A]])(implicit pos: SourceContext) = ListFromSeq(xs) @@ -74,6 +83,11 @@ trait ListOpsExp extends ListOps with EffectExp with VariablesExp { val b = reifyEffects(f(a)) reflectEffect(ListMap(l, a, b), summarizeEffects(b).star) } + def list_foreach[A:Manifest](l: Exp[List[A]], f: Exp[A] => Exp[Unit])(implicit pos: SourceContext) = { + val a = fresh[A] + val b = reifyEffects(f(a)) + reflectEffect(ListForeach(l, a, b), summarizeEffects(b).star) + } def list_flatMap[A:Manifest, B:Manifest](l: Exp[List[A]], f: Exp[A] => Exp[List[B]])(implicit pos: SourceContext) = { val a = fresh[A] val b = reifyEffects(f(a)) @@ -95,10 +109,10 @@ trait ListOpsExp extends ListOps with EffectExp with VariablesExp { def list_concat[A:Manifest](xs: Rep[List[A]], ys: Rep[List[A]])(implicit pos: SourceContext) = ListConcat(xs,ys) def list_cons[A:Manifest](x: Rep[A], xs: Rep[List[A]])(implicit pos: SourceContext) = ListCons(x,xs) def list_mkString[A:Manifest](l: Exp[List[A]])(implicit pos: SourceContext) = ListMkString(l) - def list_mkString2[A:Manifest](l: Rep[List[A]], sep:Rep[String])(implicit pos: SourceContext) = ListMkString2(l,sep) def list_head[A:Manifest](xs: Rep[List[A]])(implicit pos: SourceContext) = ListHead(xs) def list_tail[A:Manifest](xs: Rep[List[A]])(implicit pos: SourceContext) = ListTail(xs) def list_isEmpty[A:Manifest](xs: Rep[List[A]])(implicit pos: SourceContext) = ListIsEmpty(xs) + def list_contains[A:Manifest](xs: Rep[List[A]], e: Rep[A])(implicit pos: SourceContext) = ListContains(xs,e) override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = { (e match { @@ -109,6 +123,7 @@ trait ListOpsExp extends ListOps with EffectExp with VariablesExp { override def syms(e: Any): List[Sym[Any]] = e match { case ListMap(a, x, body) => syms(a):::syms(body) + case ListForeach(a, x, body) => syms(a):::syms(body) case ListFlatMap(a, _, body) => syms(a) ::: syms(body) case ListFilter(a, _, body) => syms(a) ::: syms(body) case ListSortBy(a, x, body) => syms(a):::syms(body) @@ -117,6 +132,7 @@ trait ListOpsExp extends ListOps with EffectExp with VariablesExp { override def boundSyms(e: Any): List[Sym[Any]] = e match { case ListMap(a, x, body) => x :: effectSyms(body) + case ListForeach(a, x, body) => x :: effectSyms(body) case ListFlatMap(_, x, body) => x :: effectSyms(body) case ListFilter(_, x, body) => x :: effectSyms(body) case ListSortBy(a, x, body) => x :: effectSyms(body) @@ -125,6 +141,7 @@ trait ListOpsExp extends ListOps with EffectExp with VariablesExp { override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { case ListMap(a, x, body) => freqNormal(a):::freqHot(body) + case ListForeach(a, x, body) => freqNormal(a):::freqHot(body) case ListFlatMap(a, _, body) => freqNormal(a) ::: freqHot(body) case ListFilter(a, _, body) => freqNormal(a) ::: freqHot(body) case ListSortBy(a, x, body) => freqNormal(a):::freqHot(body) @@ -152,7 +169,7 @@ trait ScalaGenListOps extends BaseGenListOps with ScalaGenEffect { import IR._ override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - case ListNew(xs) => emitValDef(sym, src"List(${(xs map {quote}).mkString(",")})") + case l@ListNew(xs) => emitValDef(sym, src"List[${l.m}](${(xs map {quote}).mkString(",")})") case ListConcat(xs,ys) => emitValDef(sym, src"$xs ::: $ys") case ListCons(x, xs) => emitValDef(sym, src"$x :: $xs") case ListHead(xs) => emitValDef(sym, src"$xs.head") @@ -160,30 +177,58 @@ trait ScalaGenListOps extends BaseGenListOps with ScalaGenEffect { case ListIsEmpty(xs) => emitValDef(sym, src"$xs.isEmpty") case ListFromSeq(xs) => emitValDef(sym, src"List($xs: _*)") case ListMkString(xs) => emitValDef(sym, src"$xs.mkString") - case ListMkString2(xs,s) => emitValDef(sym, src"$xs.mkString($s)") case ListMap(l,x,blk) => - gen"""val $sym = $l.map { $x => - |${nestedBlock(blk)} - |$blk - |}""" - case ListFlatMap(l, x, b) => - gen"""val $sym = $l.flatMap { $x => - |${nestedBlock(b)} - |$b - |}""" - case ListFilter(l, x, b) => - gen"""val $sym = $l.filter { $x => - |${nestedBlock(b)} - |$b - |}""" - case ListSortBy(l,x,blk) => - gen"""val $sym = $l.sortBy { $x => + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""$l.map { $x => + |${nestedBlock(blk)} + |$blk + |}""" + } + emitValDef(sym, strWriter.toString) + case ListForeach(l,x,blk) => { + gen"""$l.foreach { $x => |${nestedBlock(blk)} - |$blk |}""" + } + case ListFlatMap(l, x, b) => { + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""$l.flatMap { $x => + |${nestedBlock(b)} + |$b + |}""" + } + emitValDef(sym, strWriter.toString) + } + case ListFilter(l, x, b) => { + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""$l.filter { $x => + |${nestedBlock(b)} + |$b + |}""" + } + emitValDef(sym, strWriter.toString) + } + case ListSortBy(l,x,blk) => { + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + gen"""$l.sortBy { $x => + |${nestedBlock(blk)} + |$blk + |}""" + } + emitValDef(sym, strWriter.toString) + } case ListPrepend(l,e) => emitValDef(sym, src"$e :: $l") case ListToArray(l) => emitValDef(sym, src"$l.toArray") case ListToSeq(l) => emitValDef(sym, src"$l.toSeq") + case ListContains(l, e) => emitValDef(sym, src"$l.contains($e)") case _ => super.emitNode(sym, rhs) } } diff --git a/src/common/OPMOps.scala b/src/common/OPMOps.scala new file mode 100644 index 00000000..9337ce67 --- /dev/null +++ b/src/common/OPMOps.scala @@ -0,0 +1,74 @@ +package scala.lms +package common + +import java.io.PrintWriter + +import scala.lms.internal.{GenericNestedCodegen, GenerationFailedException} +import scala.reflect.SourceContext + +trait OMPOps extends Base { + def parallel_region(b: => Rep[Unit]): Rep[Unit] + def critical_region(b: => Rep[Unit]): Rep[Unit] +} + +trait OMPOpsExp extends OMPOps with BaseExp with EffectExp { + + case class ParallelRegion(b: Block[Unit]) extends Def[Unit] + def parallel_region(b: => Exp[Unit]): Exp[Unit] = { + val br = reifyEffects(b) + reflectEffect(ParallelRegion(br)) + } + + case class CriticalRegion(b: Block[Unit]) extends Def[Unit] + def critical_region(b: => Exp[Unit]): Exp[Unit] = { + val br = reifyEffects(b) + reflectEffect(CriticalRegion(br)) + } + + override def boundSyms(e: Any): List[Sym[Any]] = e match { + case ParallelRegion(b) => effectSyms(b) + case CriticalRegion(b) => effectSyms(b) + case _ => super.boundSyms(e) + } + + override def syms(e: Any): List[Sym[Any]] = e match { + case ParallelRegion(body) => syms(body) + case CriticalRegion(body) => syms(body) + case _ => super.syms(e) + } + + override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { + case ParallelRegion(body) => freqHot(body) + case CriticalRegion(body) => freqHot(body) + case _ => super.symsFreq(e) + } + +} + +trait BaseGenOMPOps extends GenericNestedCodegen { + val IR: OMPOpsExp + import IR._ +} + +trait ScalaGenOMPOps extends ScalaGenEffect with BaseGenOMPOps { + +} + +trait CGenOMPOps extends CGenEffect with BaseGenOMPOps{ + import IR._ + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case ParallelRegion(body) => + gen"""#pragma omp parallel + |{ + |${nestedBlock(body)} + |}""" + case CriticalRegion(body) => + gen"""#pragma omp critical + |{ + |${nestedBlock(body)} + |}""" + case _ => super.emitNode(sym, rhs) + } +} + diff --git a/src/common/Packages.scala b/src/common/Packages.scala index 0786eb9d..d92a5cc1 100644 --- a/src/common/Packages.scala +++ b/src/common/Packages.scala @@ -13,60 +13,64 @@ trait LiftScala extends LiftAll with LiftVariables with LiftEquals { } trait ScalaOpsPkg extends Base - with ImplicitOps with NumericOps with FractionalOps with OrderingOps with StringOps - with RangeOps with IOOps with ArrayOps with BooleanOps with PrimitiveOps with MiscOps - with Equal with IfThenElse with Variables with While with TupleOps with ListOps - with SeqOps with MathOps with CastingOps with SetOps with ObjectOps with ArrayBufferOps + with Structs with ImplicitOps with NumericOps with FractionalOps with OrderingOps + with StringOps with RangeOps with IOOps with ArrayOps with BooleanOps + with PrimitiveOps with MiscOps with Functions with Equal with IfThenElse + with Variables with While with TupleOps with ListOps with SeqOps with MathOps + with CastingOps with SetOps with ObjectOps with ArrayBufferOps + with UncheckedOps -trait ScalaOpsPkgExp extends ScalaOpsPkg - with ImplicitOpsExp with NumericOpsExp with FractionalOpsExp with OrderingOpsExp with StringOpsExp - with RangeOpsExp with IOOpsExp with ArrayOpsExp with BooleanOpsExp with PrimitiveOpsExp with MiscOpsExp - with FunctionsExp with EqualExp with IfThenElseExp with VariablesExp with WhileExp with TupleOpsExp with ListOpsExp - with SeqOpsExp with DSLOpsExp with MathOpsExp with CastingOpsExp with SetOpsExp with ObjectOpsExp with ArrayBufferOpsExp +trait ScalaOpsPkgExp extends ScalaOpsPkg + with StructExp with ImplicitOpsExp with NumericOpsExp with FractionalOpsExp with OrderingOpsExp + with StringOpsExp with RangeOpsExp with IOOpsExp with ArrayOpsExp with BooleanOpsExp + with PrimitiveOpsExp with MiscOpsExp with FunctionsExp with EqualExp with IfThenElseExp + with VariablesExp with WhileExp with TupleOpsExp with ListOpsExp with SeqOpsExp with MathOpsExp + with CastingOpsExp with SetOpsExp with ObjectOpsExp with ArrayBufferOpsExp + with UncheckedOpsExp +trait ScalaOpsPkgExpOpt extends ScalaOpsPkgExp + with StructExpOptCommon with NumericOpsExpOpt + with ArrayOpsExpOpt with ListOpsExpOpt + with EqualExpOpt with IfThenElseExpOpt with VariablesExpOpt with WhileExpOpt + with ObjectOpsExpOpt -/** - * Code gen: each target must define a code generator package. - */ - - -///////// -// Scala +/** Code gen: each target must define a code generator package. */ trait ScalaCodeGenPkg extends ScalaGenImplicitOps with ScalaGenNumericOps with ScalaGenFractionalOps with ScalaGenOrderingOps with ScalaGenStringOps with ScalaGenRangeOps with ScalaGenIOOps with ScalaGenArrayOps with ScalaGenBooleanOps with ScalaGenPrimitiveOps with ScalaGenMiscOps with ScalaGenFunctions with ScalaGenEqual with ScalaGenIfThenElse with ScalaGenVariables with ScalaGenWhile with ScalaGenTupleOps with ScalaGenListOps with ScalaGenSeqOps with ScalaGenDSLOps with ScalaGenMathOps with ScalaGenCastingOps with ScalaGenSetOps - with ScalaGenObjectOps with ScalaGenArrayBufferOps + with ScalaGenObjectOps with ScalaGenArrayBufferOps { val IR: ScalaOpsPkgExp } ///// // C -trait CCodeGenPkg extends CGenImplicitOps with CGenNumericOps with CGenFractionalOps with CGenOrderingOps - with CGenStringOps with CGenRangeOps with CGenIOOps with CGenArrayOps with CGenBooleanOps +trait COpsPkg extends ScalaOpsPkg +trait COpsPkgExp extends ScalaOpsPkgExp +trait CCodeGenPkg extends CGenDSLOps with CGenImplicitOps with CGenNumericOps with CGenFractionalOps with CGenOrderingOps + with CGenStringOps /*with CGenRangeOps*/ with CGenIOOps with CGenArrayOps with CGenBooleanOps with CGenPrimitiveOps with CGenMiscOps with CGenFunctions with CGenEqual with CGenIfThenElse - with CGenVariables with CGenWhile with CGenTupleOps with CGenListOps - with CGenSeqOps with CGenDSLOps with CGenMathOps with CGenCastingOps with CGenSetOps - with CGenObjectOps with CGenArrayBufferOps - { val IR: ScalaOpsPkgExp } + with CGenVariables with CGenWhile + with CGenMathOps with CGenCastingOps with CGenSetOps with CGenArrayBufferOps with CGenUncheckedOps + { val IR: COpsPkgExp } /////// // Cuda -trait CudaCodeGenPkg extends CudaGenImplicitOps with CudaGenNumericOps with CudaGenFractionalOps with CudaGenOrderingOps - with CudaGenStringOps with CudaGenRangeOps with CudaGenIOOps with CudaGenArrayOps with CudaGenBooleanOps - with CudaGenPrimitiveOps with CudaGenMiscOps with CudaGenFunctions with CudaGenEqual with CudaGenIfThenElse - with CudaGenVariables with CudaGenWhile with CudaGenTupleOps with CudaGenListOps - with CudaGenSeqOps with CudaGenDSLOps with CudaGenMathOps with CudaGenCastingOps with CudaGenSetOps - with CudaGenObjectOps with CudaGenArrayBufferOps +// CudaGenDSLOps will be used after all the basic generators are passed +trait CudaCodeGenPkg extends CudaGenDSLOps with CudaGenImplicitOps with CudaGenNumericOps with CudaGenFractionalOps with CudaGenOrderingOps + with CudaGenStringOps /*with CudaGenRangeOps*/ with CudaGenIOOps with CudaGenArrayOps with CudaGenBooleanOps + with CudaGenPrimitiveOps with CudaGenMiscOps /*with CudaGenFunctions*/ with CudaGenEqual with CudaGenIfThenElse + with CudaGenVariables with CudaGenWhile + with CudaGenMathOps with CudaGenCastingOps with CudaGenSetOps with CudaGenArrayBufferOps { val IR: ScalaOpsPkgExp } -///////// -// OpenCL -trait OpenCLCodeGenPkg extends OpenCLGenImplicitOps with OpenCLGenNumericOps with OpenCLGenFractionalOps with OpenCLGenOrderingOps - with OpenCLGenStringOps with OpenCLGenRangeOps with OpenCLGenIOOps with OpenCLGenArrayOps with OpenCLGenBooleanOps - with OpenCLGenPrimitiveOps with OpenCLGenMiscOps with OpenCLGenFunctions with OpenCLGenEqual with OpenCLGenIfThenElse - with OpenCLGenVariables with OpenCLGenWhile with OpenCLGenTupleOps with OpenCLGenListOps - with OpenCLGenSeqOps with OpenCLGenDSLOps with OpenCLGenMathOps with OpenCLGenCastingOps with OpenCLGenSetOps - with OpenCLGenObjectOps with OpenCLGenArrayBufferOps - { val IR: ScalaOpsPkgExp } \ No newline at end of file +//trait CudaCodeGenPkg extends CudaGenNumericOps with CudaGenRangeOps with CudaGenFractionalOps +// with CudaGenMiscOps with CudaGenFunctions with CudaGenVariables with CudaGenDSLOps with CudaGenImplicitOps { val IR: ScalaOpsPkgExp } + +trait OpenCLCodeGenPkg extends OpenCLGenDSLOps with OpenCLGenImplicitOps with OpenCLGenNumericOps with OpenCLGenFractionalOps with OpenCLGenOrderingOps + with OpenCLGenStringOps /*with OpenCLGenRangeOps*/ with OpenCLGenIOOps with OpenCLGenArrayOps with OpenCLGenBooleanOps + with OpenCLGenPrimitiveOps with OpenCLGenMiscOps /*with OpenCLGenFunctions*/ with OpenCLGenEqual with OpenCLGenIfThenElse + with OpenCLGenVariables with OpenCLGenWhile + with OpenCLGenMathOps with OpenCLGenCastingOps with OpenCLGenSetOps with OpenCLGenArrayBufferOps + { val IR: ScalaOpsPkgExp } diff --git a/src/common/RangeOps.scala b/src/common/RangeOps.scala index 8afb8dfb..3bb28f66 100644 --- a/src/common/RangeOps.scala +++ b/src/common/RangeOps.scala @@ -7,23 +7,42 @@ import scala.lms.internal.{GenericNestedCodegen, GenerationFailedException} import scala.reflect.SourceContext trait RangeOps extends Base { - // workaround for infix not working with manifests - implicit def repRangeToRangeOps(r: Rep[Range]) = new rangeOpsCls(r) - class rangeOpsCls(r: Rep[Range]){ + trait LongRange + + implicit class rangeOpsCls(r: Rep[Range]) { def foreach(f: Rep[Int] => Rep[Unit])(implicit pos: SourceContext) = range_foreach(r, f) } + implicit class lrangeOpsCls(r: Rep[LongRange]) { + def foreach(f: Rep[Long] => Rep[Unit])(implicit pos: SourceContext) = lrange_foreach(r, f) + def parforeach(f: Rep[Long] => Rep[Unit])(implicit pos: SourceContext) = lrange_par_foreach(r, f) + } def infix_until(start: Rep[Int], end: Rep[Int])(implicit pos: SourceContext) = range_until(start,end) + def infix_until(start: Rep[Long], end: Rep[Long])(implicit pos: SourceContext, o: Overloaded1) = lrange_until(start,end) def infix_start(r: Rep[Range])(implicit pos: SourceContext) = range_start(r) + def infix_start(r: Rep[LongRange])(implicit pos: SourceContext, o: Overloaded1) = lrange_start(r) def infix_step(r: Rep[Range])(implicit pos: SourceContext) = range_step(r) + def infix_step(r: Rep[LongRange])(implicit pos: SourceContext, o: Overloaded1) = lrange_step(r) def infix_end(r: Rep[Range])(implicit pos: SourceContext) = range_end(r) - //def infix_foreach(r: Rep[Range], f: Rep[Int] => Rep[Unit]) = range_foreach(r, f) + def infix_end(r: Rep[LongRange])(implicit pos: SourceContext, o: Overloaded1) = lrange_end(r) def range_until(start: Rep[Int], end: Rep[Int])(implicit pos: SourceContext): Rep[Range] def range_start(r: Rep[Range])(implicit pos: SourceContext) : Rep[Int] def range_step(r: Rep[Range])(implicit pos: SourceContext) : Rep[Int] def range_end(r: Rep[Range])(implicit pos: SourceContext) : Rep[Int] def range_foreach(r: Rep[Range], f: (Rep[Int]) => Rep[Unit])(implicit pos: SourceContext): Rep[Unit] + + def lrange_until(start: Rep[Long], end: Rep[Long])(implicit pos: SourceContext): Rep[LongRange] + def lrange_start(r: Rep[LongRange])(implicit pos: SourceContext) : Rep[Long] + def lrange_step(r: Rep[LongRange])(implicit pos: SourceContext) : Rep[Long] + def lrange_end(r: Rep[LongRange])(implicit pos: SourceContext) : Rep[Long] + def lrange_foreach(r: Rep[LongRange], f: (Rep[Long]) => Rep[Unit])(implicit pos: SourceContext): Rep[Unit] + + def lrange_par_until(start: Rep[Long], end: Rep[Long])(implicit pos: SourceContext): Rep[LongRange] + def lrange_par_start(r: Rep[LongRange])(implicit pos: SourceContext) : Rep[Long] + def lrange_par_step(r: Rep[LongRange])(implicit pos: SourceContext) : Rep[Long] + def lrange_par_end(r: Rep[LongRange])(implicit pos: SourceContext) : Rep[Long] + def lrange_par_foreach(r: Rep[LongRange], f: (Rep[Long]) => Rep[Unit])(implicit pos: SourceContext): Rep[Unit] } trait RangeOpsExp extends RangeOps with FunctionsExp { @@ -31,19 +50,44 @@ trait RangeOpsExp extends RangeOps with FunctionsExp { case class RangeStart(r: Exp[Range]) extends Def[Int] case class RangeStep(r: Exp[Range]) extends Def[Int] case class RangeEnd(r: Exp[Range]) extends Def[Int] - //case class RangeForeach(r: Exp[Range], i: Exp[Int], body: Exp[Unit]) extends Def[Unit] case class RangeForeach(start: Exp[Int], end: Exp[Int], i: Sym[Int], body: Block[Unit]) extends Def[Unit] + case class LongUntil(start: Exp[Long], end: Exp[Long]) extends Def[LongRange] + case class LongRangeStart(r: Exp[LongRange]) extends Def[Long] + case class LongRangeStep(r: Exp[LongRange]) extends Def[Long] + case class LongRangeEnd(r: Exp[LongRange]) extends Def[Long] + case class LongRangeForeach(start: Exp[Long], end: Exp[Long], i: Sym[Long], body: Block[Unit]) extends Def[Unit] + + case class LongParUntil(start: Exp[Long], end: Exp[Long]) extends Def[LongRange] + case class LongRangeParStart(r: Exp[LongRange]) extends Def[Long] + case class LongRangeParStep(r: Exp[LongRange]) extends Def[Long] + case class LongRangeParEnd(r: Exp[LongRange]) extends Def[Long] + case class LongRangeParForeach(start: Exp[Long], end: Exp[Long], i: Sym[Long], body: Block[Unit]) extends Def[Unit] + + def lrange_par_until(start: Exp[Long], end: Exp[Long])(implicit pos: SourceContext) : Exp[LongRange] = LongParUntil(start, end) + def lrange_par_start(r: Exp[LongRange])(implicit pos: SourceContext) : Exp[Long] = r match { + case Def(LongParUntil(start, end)) => start + case _ => LongRangeParStart(r) + } + def lrange_par_step(r: Exp[LongRange])(implicit pos: SourceContext) : Exp[Long] = LongRangeParStep(r) + def lrange_par_end(r: Exp[LongRange])(implicit pos: SourceContext) : Exp[Long] = r match { + case Def(LongParUntil(start, end)) => end + case _ => LongRangeParEnd(r) + } + def lrange_par_foreach(r: Exp[LongRange], block: Exp[Long] => Exp[Unit])(implicit pos: SourceContext): Exp[Unit] = { + val i = fresh[Long] + val a = reifyEffects(block(i)) + reflectEffect(LongRangeParForeach(r.start, r.end, i, a), summarizeEffects(a).star) + } + def range_until(start: Exp[Int], end: Exp[Int])(implicit pos: SourceContext) : Exp[Range] = Until(start, end) - def range_start(r: Exp[Range])(implicit pos: SourceContext) : Exp[Int] = r match { + def range_start(r: Exp[Range])(implicit pos: SourceContext) : Exp[Int] = r match { case Def(Until(start, end)) => start - case Def(Reflect(Until(start, end), u, es)) => start case _ => RangeStart(r) } def range_step(r: Exp[Range])(implicit pos: SourceContext) : Exp[Int] = RangeStep(r) - def range_end(r: Exp[Range])(implicit pos: SourceContext) : Exp[Int] = r match { + def range_end(r: Exp[Range])(implicit pos: SourceContext) : Exp[Int] = r match { case Def(Until(start, end)) => end - case Def(Reflect(Until(start, end), u, es)) => end case _ => RangeEnd(r) } def range_foreach(r: Exp[Range], block: Exp[Int] => Exp[Unit])(implicit pos: SourceContext) : Exp[Unit] = { @@ -51,33 +95,52 @@ trait RangeOpsExp extends RangeOps with FunctionsExp { val a = reifyEffects(block(i)) reflectEffect(RangeForeach(r.start, r.end, i, a), summarizeEffects(a).star) } - + + def lrange_until(start: Exp[Long], end: Exp[Long])(implicit pos: SourceContext) : Exp[LongRange] = LongUntil(start, end) + def lrange_start(r: Exp[LongRange])(implicit pos: SourceContext) : Exp[Long] = r match { + case Def(LongUntil(start, end)) => start + case _ => LongRangeStart(r) + } + def lrange_step(r: Exp[LongRange])(implicit pos: SourceContext) : Exp[Long] = LongRangeStep(r) + def lrange_end(r: Exp[LongRange])(implicit pos: SourceContext) : Exp[Long] = r match { + case Def(LongUntil(start, end)) => end + case _ => LongRangeEnd(r) + } + def lrange_foreach(r: Exp[LongRange], block: Exp[Long] => Exp[Unit])(implicit pos: SourceContext) : Exp[Unit] = { + val i = fresh[Long] + val a = reifyEffects(block(i)) + reflectEffect(LongRangeForeach(r.start, r.end, i, a), summarizeEffects(a).star) + } + override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { - case Reflect(RangeForeach(s,e,i,b), u, es) => reflectMirrored(Reflect(RangeForeach(f(s),f(e),f(i).asInstanceOf[Sym[Int]],f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(RangeStart(r), u, es) => reflectMirrored(Reflect(RangeStart(f(r)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(RangeStep(r), u, es) => reflectMirrored(Reflect(RangeStep(f(r)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(RangeEnd(r), u, es) => reflectMirrored(Reflect(RangeEnd(f(r)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(Until(s,e), u, es) => reflectMirrored(Reflect(Until(f(s),f(e)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(RangeForeach(s,e,i,b), u, es) => reflectMirrored(Reflect(RangeForeach(f(s),f(e),f(i).asInstanceOf[Sym[Int]],f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(LongRangeForeach(s,e,i,b), u, es) => reflectMirrored(Reflect(LongRangeForeach(f(s),f(e),f(i).asInstanceOf[Sym[Long]],f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(LongRangeParForeach(s,e,i,b), u, es) => reflectMirrored(Reflect(LongRangeParForeach(f(s),f(e),f(i).asInstanceOf[Sym[Long]],f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] override def syms(e: Any): List[Sym[Any]] = e match { case RangeForeach(start, end, i, body) => syms(start):::syms(end):::syms(body) + case LongRangeForeach(start, end, i, body) => syms(start):::syms(end):::syms(body) + case LongRangeParForeach(start, end, i, body) => syms(start):::syms(end):::syms(body) case _ => super.syms(e) } override def boundSyms(e: Any): List[Sym[Any]] = e match { case RangeForeach(start, end, i, y) => i :: effectSyms(y) + case LongRangeForeach(start, end, i, y) => i :: effectSyms(y) + case LongRangeParForeach(start, end, i, y) => i :: effectSyms(y) case _ => super.boundSyms(e) } override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { case RangeForeach(start, end, i, body) => freqNormal(start):::freqNormal(end):::freqHot(body) + case LongRangeForeach(start, end, i, body) => freqNormal(start):::freqNormal(end):::freqHot(body) + case LongRangeParForeach(start, end, i, body) => freqNormal(start):::freqNormal(end):::freqHot(body) case _ => super.symsFreq(e) } - } trait BaseGenRangeOps extends GenericNestedCodegen { @@ -99,16 +162,16 @@ trait ScalaGenRangeOps extends ScalaGenEffect with BaseGenRangeOps { stream.println(quote(getBlockResult(body))) stream.println("}") } - */ + */ case RangeForeach(start, end, i, body) => { // do not need to print unit result //stream.println(quote(getBlockResult(body))) - gen"""var $i : Int = $start - |val $sym = while ($i < $end) { - |${nestedBlock(body)} - |$i = $i + 1 - |}""" + gen"var $i : Int = $start" + emitValDef(sym, src"while ($i < $end) {") + gen"""${nestedBlock(body)} + |$i = $i + 1 + |}""" } case _ => super.emitNode(sym, rhs) @@ -121,28 +184,28 @@ trait CudaGenRangeOps extends CudaGenEffect with BaseGenRangeOps { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case Until(start, end) => - gen"""${addTab()}int ${sym}_start = $start; - |${addTab()}int ${sym}_end = $end;""" - // Do nothing: will be handled by RangeForeach + gen"""${addTab()}int ${sym}_start = $start; + |${addTab()}int ${sym}_end = $end;""" + // Do nothing: will be handled by RangeForeach - // TODO: What if the range is not continuous integer set? + // TODO: What if the range is not continuous integer set? case RangeForeach(start, end, i, body) => { /* - //var freeVars = buildScheduleForResult(body).filter(scope.contains(_)).map(_.sym) - val freeVars = getFreeVarBlock(body,List(i.asInstanceOf[Sym[Any]])) - // Add the variables of range to the free variable list if necessary - var paramList = freeVars - //val Until(startIdx,endIdx) = findDefinition(r.asInstanceOf[Sym[Range]]).map(_.rhs).get.asInstanceOf[Until] - if(start.isInstanceOf[Sym[Any]]) paramList = start.asInstanceOf[Sym[Any]] :: paramList - if(end.isInstanceOf[Sym[Any]]) paramList = end.asInstanceOf[Sym[Any]] :: paramList - paramList = paramList.distinct - val paramListStr = paramList.map(ele=>remap(ele.tp) + " " + quote(ele)).mkString(", ") - */ - gen"${addTab()}for(int $i=$start; $i < $end; $i++) {" - tabWidth += 1 - emitBlock(body) - tabWidth -= 1 - gen"${addTab()}}" + //var freeVars = buildScheduleForResult(body).filter(scope.contains(_)).map(_.sym) + val freeVars = getFreeVarBlock(body,List(i.asInstanceOf[Sym[Any]])) + // Add the variables of range to the free variable list if necessary + var paramList = freeVars + //val Until(startIdx,endIdx) = findDefinition(r.asInstanceOf[Sym[Range]]).map(_.rhs).get.asInstanceOf[Until] + if(start.isInstanceOf[Sym[Any]]) paramList = start.asInstanceOf[Sym[Any]] :: paramList + if(end.isInstanceOf[Sym[Any]]) paramList = end.asInstanceOf[Sym[Any]] :: paramList + paramList = paramList.distinct + val paramListStr = paramList.map(ele=>remap(ele.tp) + " " + quote(ele)).mkString(", ") + */ + gen"${addTab()}for(int $i=$start; $i < $end; $i++) {" + tabWidth += 1 + emitBlock(body) + tabWidth -= 1 + gen"${addTab()}}" } case _ => super.emitNode(sym, rhs) } @@ -157,8 +220,8 @@ trait OpenCLGenRangeOps extends OpenCLGenEffect with BaseGenRangeOps { throw new GenerationFailedException("OpenCLGenRangeOps: Range vector is not supported") case RangeForeach(start, end, i, body) => gen"""for(int $i=$start; $i < $end; $i++) { - |${nestedBlock(body)} - |}""" + |${nestedBlock(body)} + |}""" case _ => super.emitNode(sym, rhs) } @@ -172,10 +235,24 @@ trait CGenRangeOps extends CGenEffect with BaseGenRangeOps { case Until(start, end) => throw new GenerationFailedException("CGenRangeOps: Range vector is not supported") case RangeForeach(start, end, i, body) => - gen"""for(int $i=$start; $i < $end; $i++) { - |${nestedBlock(body)} - |}""" - + // Some compilers don't like the initialization inside for + stream.println(remap(i.tp) + " " + quote(i) + ";") + gen"""for($i=$start; $i < $end; $i++) { + |${nestedBlock(body)} + |}""" + case LongRangeForeach(start, end, i, body) => + // Some compilers don't like the initialization inside for + stream.println(remap(i.tp) + " " + quote(i) + ";") + gen"""for($i=$start; $i < $end; $i++) { + |${nestedBlock(body)} + |}""" + case LongRangeParForeach(start, end, i, body) => + // Some compilers don't like the initialization inside for + stream.println(remap(i.tp) + " " + quote(i) + ";") + stream.println("#pragma omp parallel for private(" + quote(i) + ")") + gen"""for($i=$start; $i < $end; $i++) { + |${nestedBlock(body)} + |}""" case _ => super.emitNode(sym, rhs) } } diff --git a/src/common/StringOps.scala b/src/common/StringOps.scala index c80911f9..6b5bdf8c 100644 --- a/src/common/StringOps.scala +++ b/src/common/StringOps.scala @@ -3,7 +3,7 @@ package common import java.io.PrintWriter import scala.lms.util.OverloadHack -import scala.lms.internal.{GenerationFailedException} +import scala.lms.internal.{GenerationFailedException,CNestedCodegen} import scala.reflect.SourceContext trait LiftString { @@ -15,114 +15,137 @@ trait LiftString { trait StringOps extends Variables with OverloadHack { // NOTE: if something doesn't get lifted, this won't give you a compile time error, // since string concat is defined on all objects - + def infix_+(s1: String, s2: Rep[Any])(implicit o: Overloaded1, pos: SourceContext) = string_plus(unit(s1), s2) def infix_+[T:Manifest](s1: String, s2: Var[T])(implicit o: Overloaded2, pos: SourceContext) = string_plus(unit(s1), readVar(s2)) - def infix_+(s1: Rep[String], s2: Rep[Any])(implicit o: Overloaded1, pos: SourceContext) = string_plus(s1, s2) - def infix_+[T:Manifest](s1: Rep[String], s2: Var[T])(implicit o: Overloaded2, pos: SourceContext) = string_plus(s1, readVar(s2)) + def infix_+[T:Manifest](s1: Rep[String], s2: Rep[T])(implicit o: Overloaded1, pos: SourceContext): Rep[String] = { + if (manifest[T] == classManifest[Array[Byte]]) + string_plus(s1, string_new(s2)) + else string_plus(s1, s2) + } + def infix_+[T:Manifest](s1: Rep[String], s2: Var[T])(implicit o: Overloaded2, pos: SourceContext): Rep[String] = string_plus(s1, readVar(s2)) def infix_+(s1: Rep[String], s2: Rep[String])(implicit o: Overloaded3, pos: SourceContext) = string_plus(s1, s2) def infix_+(s1: Rep[String], s2: Var[String])(implicit o: Overloaded4, pos: SourceContext) = string_plus(s1, readVar(s2)) def infix_+(s1: Rep[Any], s2: Rep[String])(implicit o: Overloaded5, pos: SourceContext) = string_plus(s1, s2) def infix_+(s1: Rep[Any], s2: Var[String])(implicit o: Overloaded6, pos: SourceContext) = string_plus(s1, readVar(s2)) def infix_+(s1: Rep[Any], s2: String)(implicit o: Overloaded7, pos: SourceContext) = string_plus(s1, unit(s2)) - - def infix_+(s1: Var[String], s2: Rep[Any])(implicit o: Overloaded8, pos: SourceContext) = string_plus(readVar(s1), s2) + + def infix_+(s1: Var[String], s2: Rep[Any])(implicit o: Overloaded8, pos: SourceContext) = string_plus(readVar(s1), s2) def infix_+[T:Manifest](s1: Var[String], s2: Var[T])(implicit o: Overloaded9, pos: SourceContext) = string_plus(readVar(s1), readVar(s2)) - def infix_+(s1: Var[String], s2: Rep[String])(implicit o: Overloaded10, pos: SourceContext) = string_plus(readVar(s1), s2) - def infix_+(s1: Var[String], s2: Var[String])(implicit o: Overloaded11, pos: SourceContext) = string_plus(readVar(s1), readVar(s2)) + def infix_+(s1: Var[String], s2: Rep[String])(implicit o: Overloaded10, pos: SourceContext) = string_plus(readVar(s1), s2) + def infix_+(s1: Var[String], s2: Var[String])(implicit o: Overloaded11, pos: SourceContext) = string_plus(readVar(s1), readVar(s2)) def infix_+[T:Manifest](s1: Var[T], s2: Rep[String])(implicit o: Overloaded12, pos: SourceContext) = string_plus(readVar(s1), s2) def infix_+[T:Manifest](s1: Var[T], s2: Var[String])(implicit o: Overloaded13, pos: SourceContext) = string_plus(readVar(s1), readVar(s2)) def infix_+[T:Manifest](s1: Var[T], s2: String)(implicit o: Overloaded14, pos: SourceContext) = string_plus(readVar(s1), unit(s2)) - + def infix_getBytes(s1: Rep[String])(implicit pos: SourceContext) = string_getBytes(s1) + // these are necessary to be more specific than arithmetic/numeric +. is there a more generic form of this that will work? - //def infix_+[R:Manifest](s1: Rep[String], s2: R)(implicit c: R => Rep[Any], o: Overloaded15, pos: SourceContext) = string_plus(s1, c(s2)) + //def infix_+[R:Manifest](s1: Rep[String], s2: R)(implicit c: R => Rep[Any], o: Overloaded15, pos: SourceContext) = string_plus(s1, c(s2)) def infix_+(s1: Rep[String], s2: Double)(implicit o: Overloaded15, pos: SourceContext) = string_plus(s1, unit(s2)) def infix_+(s1: Rep[String], s2: Float)(implicit o: Overloaded16, pos: SourceContext) = string_plus(s1, unit(s2)) def infix_+(s1: Rep[String], s2: Int)(implicit o: Overloaded17, pos: SourceContext) = string_plus(s1, unit(s2)) def infix_+(s1: Rep[String], s2: Long)(implicit o: Overloaded18, pos: SourceContext) = string_plus(s1, unit(s2)) - def infix_+(s1: Rep[String], s2: Short)(implicit o: Overloaded19, pos: SourceContext) = string_plus(s1, unit(s2)) - + def infix_+(s1: Rep[String], s2: Short)(implicit o: Overloaded19, pos: SourceContext) = string_plus(s1, unit(s2)) + def infix_startsWith(s1: Rep[String], s2: Rep[String])(implicit pos: SourceContext) = string_startswith(s1,s2) + def infix_endsWith(s1: Rep[String], s2: Rep[String])(implicit pos: SourceContext) = string_endswith(s1,s2) + def infix_replaceAll(s1: Rep[String], d1: Rep[String], d2: Rep[String])(implicit pos: SourceContext) = string_replaceAll(s1,d1,d2) def infix_trim(s: Rep[String])(implicit pos: SourceContext) = string_trim(s) - def infix_split(s: Rep[String], separators: Rep[String])(implicit pos: SourceContext) = string_split(s, separators, unit(0)) - def infix_split(s: Rep[String], separators: Rep[String], limit: Rep[Int])(implicit pos: SourceContext) = string_split(s, separators, limit) - def infix_charAt(s: Rep[String], i: Rep[Int])(implicit pos: SourceContext) = string_charAt(s,i) - def infix_endsWith(s: Rep[String], e: Rep[String])(implicit pos: SourceContext) = string_endsWith(s,e) - def infix_contains(s1: Rep[String], s2: Rep[String])(implicit pos: SourceContext) = string_contains(s1,s2) + def infix_length(s: Rep[String])(implicit pos: SourceContext) = string_length(s) + def infix_split(s: Rep[String], separators: Rep[String])(implicit pos: SourceContext) = string_split(s, separators) def infix_toDouble(s: Rep[String])(implicit pos: SourceContext) = string_todouble(s) def infix_toFloat(s: Rep[String])(implicit pos: SourceContext) = string_tofloat(s) def infix_toInt(s: Rep[String])(implicit pos: SourceContext) = string_toint(s) def infix_toLong(s: Rep[String])(implicit pos: SourceContext) = string_tolong(s) - def infix_substring(s: Rep[String], start: Rep[Int], end: Rep[Int])(implicit pos: SourceContext) = string_substring(s,start,end) - - // FIXME: enabling this causes trouble with DeliteOpSuite. investigate!! - //def infix_length(s: Rep[String])(implicit pos: SourceContext) = string_length(s) + def infix_substring(s: Rep[String], beginIndex: Rep[Int])(implicit pos: SourceContext) = string_substring(s, beginIndex) + def infix_substring(s: Rep[String], beginIndex: Rep[Int], endIndex: Rep[Int])(implicit pos: SourceContext) = string_substring(s, beginIndex, endIndex) object String { def valueOf(a: Rep[Any])(implicit pos: SourceContext) = string_valueof(a) } + def string_new(s: Rep[Any]): Rep[String] def string_plus(s: Rep[Any], o: Rep[Any])(implicit pos: SourceContext): Rep[String] def string_startswith(s1: Rep[String], s2: Rep[String])(implicit pos: SourceContext): Rep[Boolean] + def string_endswith(s1: Rep[String], s2: Rep[String])(implicit pos: SourceContext): Rep[Boolean] + def string_replaceAll(s1: Rep[String], d1: Rep[String], d2: Rep[String])(implicit pos: SourceContext): Rep[String] def string_trim(s: Rep[String])(implicit pos: SourceContext): Rep[String] - def string_split(s: Rep[String], separators: Rep[String], limit: Rep[Int])(implicit pos: SourceContext): Rep[Array[String]] + def string_length(s: Rep[String])(implicit pos: SourceContext): Rep[Int] + def string_split(s: Rep[String], separators: Rep[String])(implicit pos: SourceContext): Rep[Array[String]] def string_valueof(d: Rep[Any])(implicit pos: SourceContext): Rep[String] - def string_charAt(s: Rep[String], i: Rep[Int])(implicit pos: SourceContext): Rep[Char] - def string_endsWith(s: Rep[String], e: Rep[String])(implicit pos: SourceContext): Rep[Boolean] - def string_contains(s1: Rep[String], s2: Rep[String])(implicit pos: SourceContext): Rep[Boolean] def string_todouble(s: Rep[String])(implicit pos: SourceContext): Rep[Double] def string_tofloat(s: Rep[String])(implicit pos: SourceContext): Rep[Float] def string_toint(s: Rep[String])(implicit pos: SourceContext): Rep[Int] def string_tolong(s: Rep[String])(implicit pos: SourceContext): Rep[Long] - def string_substring(s: Rep[String], start:Rep[Int], end:Rep[Int])(implicit pos: SourceContext): Rep[String] - def string_length(s: Rep[String])(implicit pos: SourceContext): Rep[Int] + def string_substring(s: Rep[String], beginIndex: Rep[Int])(implicit pos: SourceContext): Rep[String] + def string_substring(s: Rep[String], beginIndex: Rep[Int], endIndex: Rep[Int])(implicit pos: SourceContext): Rep[String] + def string_getBytes(s1: Rep[String])(implicit pos: SourceContext): Rep[Array[Byte]] + def string_containsSlice(s1: Rep[String],s2:Rep[String])(implicit pos: SourceContext): Rep[Boolean] + def string_compareTo(s1: Rep[String],s2:Rep[String])(implicit pos: SourceContext): Rep[Int] + def string_indexOfSlice(s1: Rep[String],s2:Rep[String],idx:Rep[Int])(implicit pos: SourceContext): Rep[Int] } -trait StringOpsExp extends StringOps with VariablesExp { +trait StringOpsExp extends StringOps with VariablesExp with Structs { + case class StringNew(s: Rep[Any]) extends Def[String] case class StringPlus(s: Exp[Any], o: Exp[Any]) extends Def[String] case class StringStartsWith(s1: Exp[String], s2: Exp[String]) extends Def[Boolean] + case class StringEndsWith(s1: Exp[String], s2: Exp[String]) extends Def[Boolean] { + val lensuf = fresh[Int] + val lenstr = fresh[Int] + } + case class StringReplaceAll(s1: Exp[String], d1: Exp[String], d2: Exp[String]) extends Def[String] case class StringTrim(s: Exp[String]) extends Def[String] - case class StringSplit(s: Exp[String], separators: Exp[String], limit: Exp[Int]) extends Def[Array[String]] - case class StringEndsWith(s: Exp[String], e: Exp[String]) extends Def[Boolean] - case class StringCharAt(s: Exp[String], i: Exp[Int]) extends Def[Char] + case class StringLength(s: Exp[String]) extends Def[Int] + case class StringSplit(s: Exp[String], separators: Exp[String]) extends Def[Array[String]] case class StringValueOf(a: Exp[Any]) extends Def[String] case class StringToDouble(s: Exp[String]) extends Def[Double] case class StringToFloat(s: Exp[String]) extends Def[Float] case class StringToInt(s: Exp[String]) extends Def[Int] - case class StringContains(s1: Exp[String], s2: Exp[String]) extends Def[Boolean] case class StringToLong(s: Exp[String]) extends Def[Long] - case class StringSubstring(s: Exp[String], start:Exp[Int], end:Exp[Int]) extends Def[String] - case class StringLength(s: Exp[String]) extends Def[Int] + case class StringSubstring(s: Exp[String], beginIndex: Exp[Int]) extends Def[String] + case class StringGetBytes(s: Exp[String]) extends Def[Array[Byte]] + case class StringSubstringWithEndIndex(s: Exp[String], beginIndex: Exp[Int], endIndex: Exp[Int]) extends Def[String] + case class StringContainsSlice(s1: Exp[String], s2: Exp[String]) extends Def[Boolean] + case class StringCompareTo(s1: Exp[String], s2: Exp[String]) extends Def[Int] + case class StringIndexOfSlice(s1: Exp[String], s2: Exp[String], idx: Exp[Int]) extends Def[Int] + def string_new(s: Rep[Any]) = StringNew(s) def string_plus(s: Exp[Any], o: Exp[Any])(implicit pos: SourceContext): Rep[String] = StringPlus(s,o) def string_startswith(s1: Exp[String], s2: Exp[String])(implicit pos: SourceContext) = StringStartsWith(s1,s2) + def string_endswith(s1: Exp[String], s2: Exp[String])(implicit pos: SourceContext) = StringEndsWith(s1,s2) + def string_replaceAll(s1: Exp[String], d1: Exp[String], d2: Exp[String])(implicit pos: SourceContext) = StringReplaceAll(s1,d1,d2) def string_trim(s: Exp[String])(implicit pos: SourceContext) : Rep[String] = StringTrim(s) - def string_split(s: Exp[String], separators: Exp[String], limit: Exp[Int])(implicit pos: SourceContext) : Rep[Array[String]] = StringSplit(s, separators, limit) + def string_length(s: Exp[String])(implicit pos: SourceContext) : Rep[Int] = StringLength(s) + def string_split(s: Exp[String], separators: Exp[String])(implicit pos: SourceContext) : Rep[Array[String]] = StringSplit(s, separators) def string_valueof(a: Exp[Any])(implicit pos: SourceContext) = StringValueOf(a) - def string_charAt(s: Exp[String], i: Exp[Int])(implicit pos: SourceContext) = StringCharAt(s,i) - def string_endsWith(s: Exp[String], e: Exp[String])(implicit pos: SourceContext) = StringEndsWith(s,e) - def string_contains(s1: Exp[String], s2: Exp[String])(implicit pos: SourceContext) = StringContains(s1,s2) - def string_todouble(s: Rep[String])(implicit pos: SourceContext) = StringToDouble(s) - def string_tofloat(s: Rep[String])(implicit pos: SourceContext) = StringToFloat(s) - def string_toint(s: Rep[String])(implicit pos: SourceContext) = StringToInt(s) - def string_tolong(s: Rep[String])(implicit pos: SourceContext) = StringToLong(s) - def string_substring(s: Rep[String], start:Rep[Int], end:Rep[Int])(implicit pos: SourceContext) = StringSubstring(s,start,end) - def string_length(s: Rep[String])(implicit pos: SourceContext) = StringLength(s) + def string_todouble(s: Exp[String])(implicit pos: SourceContext) = StringToDouble(s) + def string_tofloat(s: Exp[String])(implicit pos: SourceContext) = StringToFloat(s) + def string_toint(s: Exp[String])(implicit pos: SourceContext) = StringToInt(s) + def string_tolong(s: Exp[String])(implicit pos: SourceContext) = StringToLong(s) + def string_substring(s: Exp[String], beginIndex: Exp[Int])(implicit pos: SourceContext) = StringSubstring(s, beginIndex) + def string_substring(s: Exp[String], beginIndex: Exp[Int], endIndex: Exp[Int])(implicit pos: SourceContext) = StringSubstringWithEndIndex(s, beginIndex, endIndex) + def string_getBytes(s1: Rep[String])(implicit pos: SourceContext) = StringGetBytes(s1) + def string_containsSlice(s1: Rep[String], s2:Rep[String])(implicit pos: SourceContext) = StringContainsSlice(s1,s2) + def string_compareTo(s1: Rep[String],s2:Rep[String])(implicit pos: SourceContext) = StringCompareTo(s1,s2) + def string_indexOfSlice(s1: Rep[String], s2:Rep[String], idx: Rep[Int])(implicit pos: SourceContext) = StringIndexOfSlice(s1,s2,idx) override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { + case StringNew(a) => string_new(f(a)) case StringPlus(a,b) => string_plus(f(a),f(b)) - case StringStartsWith(s1, s2) => string_startswith(f(s1), f(s2)) case StringTrim(s) => string_trim(f(s)) - case StringSplit(s,sep,l) => string_split(f(s),f(sep),f(l)) + case StringStartsWith(s1,s2) => string_startswith(f(s1),f(s2)) + case StringEndsWith(s1,s2) => string_endswith(f(s1),f(s2)) + case StringReplaceAll(s1,d1,d2) => string_replaceAll(f(s1),f(d1),f(d2)) + case StringSplit(s,sep) => string_split(f(s),f(sep)) case StringToDouble(s) => string_todouble(f(s)) case StringToFloat(s) => string_tofloat(f(s)) case StringToInt(s) => string_toint(f(s)) - case StringEndsWith(s, e) => string_endsWith(f(s),f(e)) - case StringCharAt(s,i) => string_charAt(f(s),f(i)) - case StringValueOf(a) => string_valueof(f(a)) - case StringContains(s1,s2) => string_contains(f(s1),f(s2)) - case StringSubstring(s,a,b) => string_substring(f(s),f(a),f(b)) - case StringLength(s) => string_length(f(s)) + case StringToLong(s) => string_tolong(f(s)) + case StringSubstring(s, beginIndex) => string_substring(f(s), f(beginIndex)) + case StringSubstringWithEndIndex(s, beginIndex, endIndex) => string_substring(f(s), f(beginIndex), f(endIndex)) + case StringContainsSlice(s1,s2) => string_containsSlice(f(s1), f(s2)) + case StringCompareTo(s1,s2) => string_compareTo(f(s1),f(s2)) + case StringIndexOfSlice(s1,s2,idx) => string_indexOfSlice(f(s1),f(s2),f(idx)) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] } @@ -130,22 +153,23 @@ trait StringOpsExp extends StringOps with VariablesExp { trait ScalaGenStringOps extends ScalaGenBase { val IR: StringOpsExp import IR._ - + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case StringNew(s1) => emitValDef(sym, src"new String($s1)") case StringPlus(s1,s2) => emitValDef(sym, src"$s1+$s2") case StringStartsWith(s1,s2) => emitValDef(sym, src"$s1.startsWith($s2)") + case StringEndsWith(s1,s2) => emitValDef(sym, src"$s1.endsWith($s2)") + case StringReplaceAll(s1,d1,d2) => emitValDef(sym, src"$s1.replaceAll($d1,$d2)") case StringTrim(s) => emitValDef(sym, src"$s.trim()") - case StringSplit(s, sep, l) => emitValDef(sym, src"$s.split($sep,$l)") - case StringEndsWith(s, e) => emitValDef(sym, "%s.endsWith(%s)".format(quote(s), quote(e))) - case StringCharAt(s,i) => emitValDef(sym, "%s.charAt(%s)".format(quote(s), quote(i))) + case StringSplit(s, sep) => emitValDef(sym, src"$s.split($sep)") case StringValueOf(a) => emitValDef(sym, src"java.lang.String.valueOf($a)") case StringToDouble(s) => emitValDef(sym, src"$s.toDouble") case StringToFloat(s) => emitValDef(sym, src"$s.toFloat") case StringToInt(s) => emitValDef(sym, src"$s.toInt") case StringToLong(s) => emitValDef(sym, src"$s.toLong") - case StringContains(s1,s2) => emitValDef(sym, "%s.contains(%s)".format(quote(s1),quote(s2))) - case StringSubstring(s,a,b) => emitValDef(sym, src"$s.substring($a,$b)") - case StringLength(s) => emitValDef(sym, src"$s.length") + case StringGetBytes(s) => emitValDef(sym, src"$s.getBytes") + case StringSubstring(s, beginIndex) => emitValDef(sym, src"$s.substring($beginIndex)") + case StringSubstringWithEndIndex(s, beginIndex, endIndex) => emitValDef(sym, src"$s.substring($beginIndex, $endIndex)") case _ => super.emitNode(sym, rhs) } } @@ -153,44 +177,53 @@ trait ScalaGenStringOps extends ScalaGenBase { trait CudaGenStringOps extends CudaGenBase { val IR: StringOpsExp import IR._ + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case StringPlus(s1,s2) => throw new GenerationFailedException("CudaGen: Not GPUable") + case StringTrim(s) => throw new GenerationFailedException("CudaGen: Not GPUable") + case StringSplit(s, sep) => throw new GenerationFailedException("CudaGen: Not GPUable") + case _ => super.emitNode(sym, rhs) + } } trait OpenCLGenStringOps extends OpenCLGenBase { val IR: StringOpsExp import IR._ + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case StringPlus(s1,s2) => throw new GenerationFailedException("OpenCLGen: Not GPUable") + case StringTrim(s) => throw new GenerationFailedException("OpenCLGen: Not GPUable") + case StringSplit(s, sep) => throw new GenerationFailedException("OpenCLGen: Not GPUable") + case _ => super.emitNode(sym, rhs) + } } -trait CGenStringOps extends CGenBase { +trait CGenStringOps extends CGenBase with CNestedCodegen { val IR: StringOpsExp import IR._ override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - case StringPlus(s1,s2) if remap(s1.tp) == "string" && remap(s2.tp) == "string" => emitValDef(sym,"string_plus(%s,%s)".format(quote(s1),quote(s2))) - case StringStartsWith(s1,s2) => emitValDef(sym, "string_startsWith(%s,%s)".format(quote(s1),quote(s2))) - case StringTrim(s) => emitValDef(sym, "string_trim(%s)".format(quote(s))) - case StringSplit(s, sep, Const(0)) => emitValDef(sym, "string_split(%s,%s)".format(quote(s),quote(sep))) - //case StringEndsWith(s, e) => emitValDef(sym, "(strlen(%s)>=strlen(%s)) && strncmp(%s+strlen(%s)-strlen(%s),%s,strlen(%s))".format(quote(s),quote(e),quote(s),quote(e),quote(s),quote(e),quote(e))) - case StringCharAt(s,i) => emitValDef(sym, "string_charAt(%s,%s)".format(quote(s), quote(i))) - //case StringValueOf(a) => - case StringToDouble(s) => emitValDef(sym, "string_toDouble(%s)".format(quote(s))) - case StringToFloat(s) => emitValDef(sym, "string_toFloat(%s)".format(quote(s))) - case StringToInt(s) => emitValDef(sym, "string_toInt(%s)".format(quote(s))) -/* - case StringSubstring(s,a,b) => emitValDef(sym, src"({ int l=$b-$a; char* r=(char*)malloc(l); memcpy(r,((char*)$s)+$a,l); r[l]=0; r; })") - case StringPlus(s1,s2) => s2.tp.toString match { - // Warning: memory leaks. We need a global mechanism like reference counting, possibly release pool(*) wrapping functions. - // (*) See https://developer.apple.com/library/mac/documentation/Cocoa/Reference/Foundation/Classes/NSAutoreleasePool_Class/Reference/Reference.html - case "java.lang.String" => emitValDef(sym,src"({ int l1=strlen($s1),l2=strlen($s2); char* r=(char*)malloc(l1+l2+1); memcpy(r,$s1,l1); memcpy(r+l1,$s2,l2); r[l1+l2]=0; r; })") - case "Char" => emitValDef(sym,src"({ int l1=strlen($s1); char* r=(char*)malloc(l1+2); memcpy(r,$s1,l1); r[l1]=$s2; r[l1+2]=0; r; })") + case StringNew(s1) => emitValDef(sym, src"$s1") + case StringLength(s1) => emitValDef(sym, src"tpch_strlen($s1)") + case StringPlus(s1,s2) => emitValDef(sym,src"strcat($s1,$s2);") + case StringStartsWith(s1,s2) => emitValDef(sym, "strncmp(" + quote(s1) + "," + quote(s2) + ", tpch_strlen(" + quote(s2) + ")) == 0;") + case sew@StringEndsWith(s1,s2) => { + emitValDef(sew.lenstr,"tpch_strlen("+quote(s1)+")") + emitValDef(sew.lensuf,"tpch_strlen("+quote(s2)+")") + emitValDef(sym, "strncmp(" + quote(s1) + "+" + quote(sew.lenstr) + "-" + quote(sew.lensuf) + "," + quote(s2) + ", " + quote(sew.lensuf) + ") == 0;") } - case StringToInt(s) => emitValDef(sym,src"atoi($s)") - case StringToLong(s) => emitValDef(sym,src"atol($s)") - case StringToFloat(s) => emitValDef(sym,src"atof($s)") - case StringToDouble(s) => emitValDef(sym,src"atof($s)") - case StringLength(s) => emitValDef(sym, src"strlen($s)") + case StringContainsSlice(s1,s2) => + emitValDef(sym, "tpch_strstr(" + quote(s1) + "," + quote(s2) + ") >= " + quote(s1)) + case StringCompareTo(s1,s2) => + emitValDef(sym, "tpch_strcmp(" + quote(s1) + "," + quote(s2) + ")") + case StringIndexOfSlice(s1,s2,idx) => + emitValDef(sym, "tpch_strstr(&(" + quote(s1) + "[" + quote(idx) + "])," + quote(s2) + ") - " + quote(s1)) + stream.println("if (" + quote(sym) + " < 0) " + quote(sym) + " = -1;") + case StringSubstringWithEndIndex(s, beginIndex, endIndex) => + emitValDef(sym, src"malloc($endIndex - $beginIndex + 1); memcpy(" + quote(sym) + "," + quote(s) + src", $endIndex - $beginIndex);" ) + // stream.println(src"char " + quote(sym) + src"[$endIndex - $beginIndex + 1]; memcpy(" + quote(sym) + "," + quote(s) + src", $endIndex - $beginIndex);") UNSAFE case StringTrim(s) => throw new GenerationFailedException("CGenStringOps: StringTrim not implemented yet") case StringSplit(s, sep) => throw new GenerationFailedException("CGenStringOps: StringSplit not implemented yet") -*/ case _ => super.emitNode(sym, rhs) } } diff --git a/src/common/Structs.scala b/src/common/Structs.scala index 6e23349a..1fae7bbf 100644 --- a/src/common/Structs.scala +++ b/src/common/Structs.scala @@ -1,29 +1,53 @@ package scala.lms package common -import reflect.{SourceContext, RefinedManifest} -import util.OverloadHack -import java.io.PrintWriter -import internal.{GenericNestedCodegen, GenericFatCodegen} +import scala.lms.common._ -abstract class Record extends Struct +import scala.lms.util.OverloadHack +import scala.lms.internal.{GenericNestedCodegen, GenericFatCodegen} +import scala.reflect.{SourceContext, RefinedManifest} +import java.io.{StringWriter,PrintWriter} +import scala.language.dynamics -trait StructOps extends Base { +/** + * Taken char-for-char from the delite-develop branch of lms + */ + +trait Structs extends Base with Variables { /** * Allows to write things like “val z = new Record { val re = 1.0; val im = -1.0 }; print(z.re)” */ - + abstract class Record extends Struct + abstract class CompositeRecord[T1:Manifest,T2:Manifest] extends Record def __new[T : Manifest](args: (String, Boolean, Rep[T] => Rep[_])*): Rep[T] = record_new(args) - class RecordOps(record: Rep[Record]) { - def selectDynamic[T : Manifest](field: String): Rep[T] = record_select[T](record, field) + class RecordOps[T1<:Record:Manifest](record: Rep[T1]) extends Dynamic { + def apply[TF: Manifest](field: String): Rep[TF] = record_select[T1,TF](record, field) + def selectDynamic[TF : Manifest](field: String): Rep[TF] = record_select[T1,TF](record, field) + def concatenate[T2: Manifest](record2: Rep[T2], leftAlias: String = "", rightAlias: String = ""): Rep[CompositeRecord[T1,T2]] = record_concatenate[T1,T2](record, record2, leftAlias, rightAlias) + def print = record_print[T1](record) + } + implicit def recordToRecordOps[T<:Record:Manifest](record: Rep[T]) = new RecordOps[T](record) + implicit def varrecordToRecordOps[T<:Record:Manifest](record: Var[T]) = new RecordOps[T](readVar(record)) + def infix_f[T<:Record:Manifest](record: Rep[T]) = new RecordOps(record)(manifest[T]) + + def registerStruct[T<:Record:Manifest](name: String, elems: Seq[(String, Manifest[_])]) + def structName[T](m: Manifest[T]): String + + object DefaultRecord { + def apply[T:Manifest]() = default_record[T]() } - implicit def recordToRecordOps(record: Rep[Record]) = new RecordOps(record) + def default_record[T: Manifest](): Rep[T] def record_new[T : Manifest](fields: Seq[(String, Boolean, Rep[T] => Rep[_])]): Rep[T] - def record_select[T : Manifest](record: Rep[Record], field: String): Rep[T] + def record_new[T:Manifest](structName: String, fieldSyms: Seq[(String, Rep[Any])]): Rep[T] + def record_select[T1:Manifest, TF:Manifest](record: Rep[T1], field: String): Rep[TF] + def record_concatenate[T1:Manifest, T2:Manifest](record: Rep[T1], record2: Rep[T2], leftAlias: String = "", rightAlias: String = ""): Rep[CompositeRecord[T1,T2]] + def record_print[T<:Record:Manifest](record: Rep[T]): Rep[Unit] def field[T:Manifest](struct: Rep[Any], index: String)(implicit pos: SourceContext): Rep[T] + def record_hash[T:Manifest](record: Rep[T]): Rep[Int] + def record_equals[T:Manifest](record1: Rep[T], record2: Rep[T]): Rep[Boolean] } trait StructTags { @@ -34,7 +58,7 @@ trait StructTags { case class MapTag[T]() extends StructTag[T] } -trait StructExp extends StructOps with StructTags with BaseExp with EffectExp with VariablesExp with ObjectOpsExp with StringOpsExp with OverloadHack { +trait StructExp extends Structs with StructTags with EffectExp with WhileExp with VariablesExp with ObjectOpsExp with StringOpsExp with FunctionsExp with MiscOpsExp with RangeOpsExp with ArrayOps with BooleanOps with Equal with PrimitiveOps with NumericOps with OrderingOps { // TODO: structs should take Def parameters that define how to generate constructor and accessor calls @@ -49,11 +73,11 @@ trait StructExp extends StructOps with StructTags with BaseExp with EffectExp wi } /* override def fresh[T:Manifest] = manifest[T] match { - case s if s <:< manifest[Record] => - val m = spawnRefinedManifest - super.fresh(m) - case _ => super.fresh - } */ //TODO: best way to ensure full structural type is always available? +case s if s <:< manifest[Record] => +val m = spawnRefinedManifest +super.fresh(m) +case _ => super.fresh +} */ //TODO: best way to ensure full structural type is always available? object Struct { def unapply[T](d: Def[T]) = unapplyStruct(d) @@ -73,8 +97,22 @@ trait StructExp extends StructOps with StructTags with BaseExp with EffectExp wi case _ => None } - case class SimpleStruct[T](tag: StructTag[T], elems: Seq[(String, Rep[Any])]) extends AbstractStruct[T] - case class FieldApply[T](struct: Rep[Any], index: String) extends AbstractField[T] + case class DefaultRecordDef[T:Manifest]() extends Def[T] { + val m = manifest[T] + } + case class SimpleStruct[T:Manifest](tag: StructTag[T], elems: Seq[(String, Rep[Any])]) extends AbstractStruct[T] { + if (tag.isInstanceOf[ClassTag[_]]) registerStruct(tag.asInstanceOf[ClassTag[_]].name, elems.map(e => (e._1, e._2.tp))) + } + case class ConcatenateRecords[T1:Manifest, T2:Manifest](x: Rep[T1], y: Rep[T2], leftAlias: String, rightAlias: String) extends Def[CompositeRecord[T1,T2]] { + val m1 = manifest[T1] + val m2 = manifest[T2] + } + case class RecordPrint[T<:Record:Manifest](rec: Rep[T]) extends Def[Unit] + case class RecordHash[T:Manifest](rec: Rep[T]) extends Def[Int] + case class RecordEquals[T:Manifest](rec1: Rep[T], rec2: Rep[T]) extends Def[Boolean] + case class FieldApply[T:Manifest](struct: Rep[Any], index: String) extends AbstractField[T] { + val m = manifest[T] + } case class FieldUpdate[T:Manifest](struct: Exp[Any], index: String, rhs: Exp[T]) extends Def[Unit] def struct[T:Manifest](tag: StructTag[T], elems: (String, Rep[Any])*)(implicit o: Overloaded1, pos: SourceContext): Rep[T] = struct[T](tag, elems) @@ -84,34 +122,45 @@ trait StructExp extends StructOps with StructTags with BaseExp with EffectExp wi def var_field[T:Manifest](struct: Rep[Any], index: String)(implicit pos: SourceContext): Var[T] = Variable(FieldApply[Var[T]](struct, index)) def field_update[T:Manifest](struct: Exp[Any], index: String, rhs: Exp[T]): Exp[Unit] = reflectWrite(struct)(FieldUpdate(struct, index, rhs)) + def default_record[T: Manifest]() = DefaultRecordDef[T]() def record_new[T : Manifest](fields: Seq[(String, Boolean, Rep[T] => Rep[_])]) = { val x: Sym[T] = Sym[T](-99) // self symbol -- not defined anywhere, so make it obvious!! (TODO) val fieldSyms = fields map { case (index, false, rhs) => (index, rhs(x)) case (index, true, rhs) => (index, var_new(rhs(x)).e) } - struct(AnonTag(manifest.asInstanceOf[RefinedManifest[T]]), fieldSyms) - } - - def record_select[T : Manifest](record: Rep[Record], fieldName: String) = { - field(record, fieldName) - } - - def imm_field(struct: Exp[Any], name: String, f: Exp[Any])(implicit pos: SourceContext): Exp[Any] = { - if (f.tp.erasure.getSimpleName == "Variable") { - field(struct,name)(mtype(f.tp.typeArguments(0)),pos) - } - else { - object_unsafe_immutable(f)(mtype(f.tp),pos) - } - } - - // don't let unsafeImmutable hide struct-ness - override def object_unsafe_immutable[A:Manifest](lhs: Exp[A])(implicit pos: SourceContext) = lhs match { - case Def(Struct(tag,elems)) => struct[A](tag, elems.map(t => (t._1, imm_field(lhs, t._1, t._2)))) - case Def(d@Reflect(Struct(tag, elems), u, es)) => struct[A](tag, elems.map(t => (t._1, imm_field(lhs, t._1, t._2)))) - case _ => super.object_unsafe_immutable(lhs) - } + struct(ClassTag(structName(manifest[T])),/*AnonTag(manifest[T].asInstanceOf[RefinedManifest[T]]),*/ fieldSyms) + } + + def record_new[T:Manifest](structName: String, fieldSyms: Seq[(String, Rep[Any])]) = { + struct(ClassTag(structName), fieldSyms) + } + + def record_select[T1 : Manifest, TF: Manifest](record: Rep[T1], fieldName: String) = { + field[TF](record, fieldName) + } + def record_concatenate[T1:Manifest,T2:Manifest](record: Rep[T1], record2: Rep[T2], leftAlias: String = "", rightAlias: String = "") = { + val name1 = structName(manifest[T1]).replace("CompositeRecord","") + val s1 = { + if (encounteredStructs.contains(name1)) encounteredStructs(name1) + else name1.split("Anon").filter(x => x.length != 0).map(x => encounteredStructs("Anon" + x)).flatten.toList + } + val name2 = structName(manifest[T2]).replace("CompositeRecord","") + val s2 = { + if (encounteredStructs.contains(name2)) encounteredStructs(name2) + else name2.split("Anon").filter(x => x.length != 0).map(x => encounteredStructs("Anon" + x)).flatten.toList + } + val elems: Seq[(String, Manifest[_])] = { + for (s <- s1) yield leftAlias + s._1 -> s._2 + } ++ { + for (s <- s2) yield rightAlias + s._1 -> s._2 + } + registerStruct(name1 + name2, elems) + /*reflectEffect*/(ConcatenateRecords(record, record2, leftAlias, rightAlias)(record.tp,record2.tp)) + } + def record_print[T<:Record:Manifest](rec: Rep[T]) = reflectEffect(RecordPrint[T](rec)) + def record_hash[T:Manifest](rec: Rep[T]) = reflectEffect(RecordHash(rec)) + def record_equals[T:Manifest](rec1: Rep[T], rec2: Rep[T]) = reflectEffect(RecordEquals(rec1,rec2)) override def syms(e: Any): List[Sym[Any]] = e match { case s:AbstractStruct[_] => s.elems.flatMap(e => syms(e._2)).toList @@ -167,18 +216,26 @@ trait StructExp extends StructOps with StructTags with BaseExp with EffectExp wi override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { case SimpleStruct(tag, elems) => struct(tag, elems map { case (k,v) => (k, f(v)) })(mtype(manifest[A]),pos) - case FieldApply(struct, key) => field(f(struct), key)(mtype(manifest[A]),pos) - case Reflect(FieldApply(struct, key), u, es) => reflectMirrored(Reflect(FieldApply(f(struct), key), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(FieldUpdate(struct, key, rhs), u, es) => reflectMirrored(Reflect(FieldUpdate(f(struct), key, f(rhs)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(SimpleStruct(tag, elems), u, es) => reflectMirrored(Reflect(SimpleStruct(tag, elems map { case (k,v) => (k, f(v)) }), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case fa@FieldApply(struct, key) => record_select(f(struct), key)(struct.tp,fa.m) + case cr@ConcatenateRecords(rec1,rec2,leftAlias,rightAlias) => record_concatenate(f(rec1),f(rec2),leftAlias,rightAlias)(rec1.tp,rec2.tp) + case Reflect(fa@FieldApply(struct, key), u, es) => record_select(f(struct),key)(struct.tp,fa.m) + case Reflect(FieldUpdate(struct, key, rhs), u, es) => reflectMirrored(Reflect(FieldUpdate(f(struct), key, f(rhs)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(SimpleStruct(tag, elems), u, es) => reflectMirrored(Reflect(SimpleStruct(tag, elems map { case (k,v) => (k, f(v)) }), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(RecordPrint(rec), u, es) => reflectMirrored(Reflect(RecordPrint(f(rec)), mapOver(f,u), f(es))) + case Reflect(RecordHash(rec), u, es) => reflectMirrored(Reflect(RecordHash(f(rec)), mapOver(f,u), f(es))) + case Reflect(RecordEquals(rec1,rec2), u, es) => reflectMirrored(Reflect(RecordEquals(f(rec1),f(rec2)), mapOver(f,u), f(es))) + case Reflect(cr@ConcatenateRecords(rec1,rec2,leftAlias,rightAlias), u, es) => record_concatenate(f(rec1),f(rec2),leftAlias,rightAlias)(rec1.tp,rec2.tp) + case dr@DefaultRecordDef() => default_record()(dr.m) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] def structName[T](m: Manifest[T]): String = m match { // FIXME: move to codegen? we should be able to have different policies/naming schemes - case rm: RefinedManifest[_] => "Anon" + math.abs(rm.fields.map(f => f._1.## + f._2.toString.##).sum) - case _ if (m <:< manifest[AnyVal]) => m.toString - case _ if m.erasure.isArray => "ArrayOf" + structName(m.typeArguments.head) + case s if s <:< manifest[CompositeRecord[Any,Any]] => m.typeArguments.map(structName(_)).mkString + case rm: RefinedManifest[_] => "Anon" + math.abs(rm.fields.zipWithIndex.map { case (f, i) => (f._1.## + f._2.toString.##) * (i + 1) /* don't want to multiply by 0 */ }.sum) + case s if m.erasure.isArray => "ArrayOf" + m.typeArguments.map(a => structName(a)).mkString("") + case s if (m <:< manifest[AnyVal]) => m.toString + case s if m.erasure.getSimpleName == "Tuple2" => "Tuple2" + m.typeArguments.foldLeft("")((x,y) => x + structName(y)) case _ => m.erasure.getSimpleName + m.typeArguments.map(a => structName(a)).mkString("") } @@ -191,8 +248,8 @@ trait StructExp extends StructOps with StructTags with BaseExp with EffectExp wi case _ => super.object_tostring(x) } - def registerStruct[T](name: String, elems: Seq[(String, Rep[Any])]) { - encounteredStructs += name -> elems.map(e => (e._1, e._2.tp)) + def registerStruct[T<:Record:Manifest](name: String, elems: Seq[(String, Manifest[_])]) = { + encounteredStructs += name -> elems } val encounteredStructs = new scala.collection.mutable.HashMap[String, Seq[(String, Manifest[_])]] } @@ -223,10 +280,10 @@ trait StructExpOpt extends StructExp { //TODO: need to be careful unwrapping Structs of vars since partial unwrapping can result in reads & writes to two different memory locations in the generated code //(the original var and the struct) /* override def var_field[T:Manifest](struct: Exp[Any], index: String)(implicit pos: SourceContext): Var[T] = fieldLookup(struct, index) match { - case Some(x: Exp[Var[T]]) if x.tp == manifest[Var[T]] => Variable(x) - case Some(x) => throw new RuntimeException("ERROR: " + index + " is not a variable field of type " + struct.tp) - case None => super.var_field(struct, index) - } */ +case Some(x: Exp[Var[T]]) if x.tp == manifest[Var[T]] => Variable(x) +case Some(x) => throw new RuntimeException("ERROR: " + index + " is not a variable field of type " + struct.tp) +case None => super.var_field(struct, index) +} */ } @@ -305,12 +362,12 @@ trait StructFatExpOptCommon extends StructFatExp with StructExpOptCommon with If def phiB[T:Manifest](c: Exp[Boolean], a1: Block[Unit], a2: Block[T], b1: Block[Unit], b2: Block[T])(parent: Exp[Unit]): Exp[T] = if (a2 == b2) a2.res else Phi(c,a1,a2,b1,b2)(parent) // FIXME: duplicate override def syms(x: Any): List[Sym[Any]] = x match { - // case Phi(c,a,u,b,v) => syms(List(c,a,b)) + // case Phi(c,a,u,b,v) => syms(List(c,a,b)) case _ => super.syms(x) } override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { - // case Phi(c,a,u,b,v) => freqNormal(c) ++ freqCold(a) ++ freqCold(b) + // case Phi(c,a,u,b,v) => freqNormal(c) ++ freqCold(a) ++ freqCold(b) case _ => super.symsFreq(e) } @@ -333,7 +390,7 @@ trait StructFatExpOptCommon extends StructFatExp with StructExpOptCommon with If override def ifThenElse[T:Manifest](cond: Rep[Boolean], a: Block[T], b: Block[T])(implicit pos: SourceContext) = (deReify(a),deReify(b)) match { case ((u, Def(Struct(tagA,elemsA))), (v, Def(Struct(tagB, elemsB)))) => //assert(tagA == tagB, tagA+" !== "+tagB) - if (tagA != tagB) println("ERROR: "+tagA+" !== "+tagB) + if (tagA != tagB) System.out.println("ERROR: "+tagA+" !== "+tagB) // create stm that computes all values at once // return struct of syms val combinedResult = super.ifThenElse(cond,u,v) @@ -346,7 +403,6 @@ trait StructFatExpOptCommon extends StructFatExp with StructExpOptCommon with If case _ => super.ifThenElse(cond,a,b) } - } trait BaseGenFatStruct extends GenericFatCodegen { @@ -368,7 +424,7 @@ trait BaseGenFatStruct extends GenericFatCodegen { val ss = phis collect { case TP(s, _) => s } val us = phis collect { case TP(_, Phi(c,a,u,b,v)) => u } // assert c,a,b match val vs = phis collect { case TP(_, Phi(c,a,u,b,v)) => v } - val c = phis collect { case TP(_, Phi(c,a,u,b,v)) => c } reduceLeft { (c1,c2) => assert(c1 == c2); c1 } + val c = phis collect { case TP(_, Phi(c,a,u,b,v)) => c } reduceLeft { (c1,c2) => assert(c1 == c2); c1 } TTP(ss, phis map (_.rhs), SimpleFatIfThenElse(c,us,vs)) } def fatif(s:Sym[Unit],o:Def[Unit],c:Exp[Boolean],a:Block[Unit],b:Block[Unit]) = fatphi(s) match { @@ -416,26 +472,77 @@ trait ScalaGenStruct extends ScalaGenBase with BaseGenStruct { import IR._ override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case r@DefaultRecordDef() => + def defaultValue[T: ClassManifest]: T = classManifest[T].erasure.toString match { + case "boolean" => false.asInstanceOf[T] + case "byte" => (0: Byte).asInstanceOf[T] + case "short" => (0: Short).asInstanceOf[T] + case "char" => '\0'.asInstanceOf[T] + case "int" => 0.asInstanceOf[T] + case "long" => 0L.asInstanceOf[T] + case "float" => 0.0F.asInstanceOf[T] + case "double" => 0.0.asInstanceOf[T] + case _ => null.asInstanceOf[T] + } + val name = structName(r.m) + val fieldTypes = encounteredStructs(name).map(_._2) + val caseClassStr = if (fieldTypes.length < 22) "" else "new " + stream.println("val " + quote(sym) + " = " + caseClassStr + name + "(" + (for (f <- fieldTypes) yield defaultValue(f)).mkString(",") + ")") case Struct(tag, elems) => - registerStruct(structName(sym.tp), elems) - emitValDef(sym, "new " + structName(sym.tp) + "(" + elems.map(e => quote(e._2)).mkString(",") + ")") + val header = if (elems.length < 22) "" else "new " + emitValDef(sym, header + structName(sym.tp) + "(" + elems.map(e => quote(e._2)).mkString(",") + ")") case FieldApply(struct, index) => emitValDef(sym, quote(struct) + "." + index) case FieldUpdate(struct, index, rhs) => emitValDef(sym, quote(struct) + "." + index + " = " + quote(rhs)) + case RecordPrint(t) => + stream.println(src"println($t.toString)") + case RecordHash(t) => + emitValDef(sym, quote(t) + ".hashCode") + case RecordEquals(t1,t2) => + emitValDef(sym, quote(t1) + " equals " + quote(t2)) + case c@ConcatenateRecords(record1, record2, leftAlias, rightAlias) => + val name1 = structName(c.m1).replace("CompositeRecord", "") + val s1 = encounteredStructs(name1) + val name2 = structName(c.m2).replace("CompositeRecord", "") + val s2 = encounteredStructs(name2) + val header = if (s1.length + s2.length < 22) "" else "new " + emitValDef(sym, header + name1 + name2 + "(" + s1.map(x => quote(record1) + "." + x._1).mkString(",") + "," + s2.map(x => quote(record2) + "." + x._1).mkString(",") + ")") case _ => super.emitNode(sym, rhs) } override def remap[A](m: Manifest[A]) = m match { - case s if s <:< manifest[Record] => structName(m) + case s if s <:< manifest[Record] => structName(m).replace("CompositeRecord","") case _ => super.remap(m) } override def emitDataStructures(stream: PrintWriter) { for ((name, elems) <- encounteredStructs) { stream.println() - stream.print("case class " + name + "(") - stream.println(elems.map(e => e._1 + ": " + remap(e._2)).mkString(", ") + ")") + if (elems.length < 22) stream.print("case ") + stream.print("class " + name + "(") + stream.println(elems.map(e => "val " + e._1 + ": " + remap(e._2)).mkString(", ") + ") {") + stream.println("override def toString() = {") + stream.println(elems.map(e => { + if (e._2 == manifest[Array[Byte]]) "(if(" + e._1 + " != null) new String(" + e._1 + ") else \"\")" + else if (e._2.erasure.isArray) e._1 + ".mkString(\" \")" + else e._1 + }).mkString(" + \"|\" + ") + "+\"\"") + stream.println("}") + stream.println("override def hashCode() = {") + stream.println(elems.map(e => { + if (e._2.erasure.isArray) e._1 + ".foldLeft(0) { (cnt,x) => cnt + x.## }" + else e._1 + ".hashCode" + }).mkString(" + ")) + stream.println("}") + stream.println("override def equals(y: Any) = {") + stream.println("val e = y.asInstanceOf[" + name + "]") + stream.println(elems.map(e => { + if (e._2.erasure.isArray) e._1 + ".corresponds(e." + e._1 + "){_==_}" + else e._1 + " == e." + e._1 + }).mkString(" && ")) + stream.println("}") + stream.println("}") } stream.flush() super.emitDataStructures(stream) @@ -443,10 +550,97 @@ trait ScalaGenStruct extends ScalaGenBase with BaseGenStruct { } -trait CGenStruct extends CGenBase with BaseGenStruct -trait CudaGenStruct extends CudaGenBase with BaseGenStruct +trait CGenStruct extends CGenBase with BaseGenStruct { + val IR: StructExp + import IR._ + def remapToPrintFDescr[A:Manifest](m: Manifest[A]): String = m match { + case s if m == manifest[Int] => "%d|" + case s if m == manifest[Double] => "%lf|" + case s if m == manifest[java.lang.Character] => "%c|" + case s if m == manifest[Long] => "%lu|" + case s if m == manifest[Array[Byte]] => "%s|" + case s if m == manifest[Byte] => "%c|" + case s if m == manifest[java.lang.String] => "%s|" + case dflt@_ => throw new Exception("Unsupported printf descr " + dflt + " when emitting struct. Stringifying...") + } + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case r@DefaultRecordDef() => + def defaultValue[T: ClassManifest]: T = classManifest[T].erasure.toString match { + case "boolean" => false.asInstanceOf[T] + case "byte" => (0: Byte).asInstanceOf[T] + case "short" => (0: Short).asInstanceOf[T] + case "char" => '\0'.asInstanceOf[T] + case "int" => 0.asInstanceOf[T] + case "long" => 0L.asInstanceOf[T] + case "float" => 0.0F.asInstanceOf[T] + case "double" => 0.0.asInstanceOf[T] + case _ => "NULL".asInstanceOf[T] + } + val name = structName(r.m) + val fields = encounteredStructs(name) + allocStruct(sym, "struct " + structName(r.m), stream) + //fields.foreach( f => stream.println(quote(sym) + "->" + f._1 + " = " + defaultValue(f._2) + ";")) + fields.foreach( f => stream.println(quote(sym) + "." + f._1 + " = " + defaultValue(f._2) + ";")) + case Struct(tag, elems) => + val fields = encounteredStructs(tag.asInstanceOf[ClassTag[_]].name).map(x => x._1) zip elems + allocStruct(sym, "struct " + tag.asInstanceOf[ClassTag[_]].name, stream) + //fields.foreach( f => stream.println(quote(sym) + "->" + f._1 + " = " + quote(f._2._2) + ";")) + fields.foreach( f => stream.println(quote(sym) + "." + f._1 + " = " + quote(f._2._2) + ";")) + case fa@FieldApply(struct, index) => + //emitValDef(sym, quote(struct) + "->" + index + ";") + emitValDef(sym, quote(struct) + "." + index + ";") + case FieldUpdate(struct, index, rhs) => + emitValDef(sym, quote(struct) + "->" + index + " = " + quote(rhs)) + case RecordPrint(t) => + stream.println("print_" + structName(t.tp) + "(" + quote(t) + ");") + case c@ConcatenateRecords(record1, record2, leftAlias, rightAlias) => + val name1 = structName(record1.tp).replace("CompositeRecord", "") + val s1 = encounteredStructs(name1) + val name2 = structName(record2.tp).replace("CompositeRecord", "") + val s2 = encounteredStructs(name2) + allocStruct(sym, remap(sym.tp).replace("*",""), stream) + //stream.println(s1.map(x => quote(sym) + "->" + leftAlias + x._1 + " = " + quote(record1) + "->" + x._1).mkString(";\n") + ";") + //stream.println(s2.map(x => quote(sym) + "->" + rightAlias + x._1 + " = " + quote(record2) + "->" + x._1).mkString(";\n") + ";") + stream.println(s1.map(x => quote(sym) + "." + leftAlias + x._1 + " = " + quote(record1) + "." + x._1).mkString(";\n") + ";") + stream.println(s2.map(x => quote(sym) + "." + rightAlias + x._1 + " = " + quote(record2) + "." + x._1).mkString(";\n") + ";") + case _ => super.emitNode(sym, rhs) + } + + override def remap[A](m: Manifest[A]) = m match { + case s if s <:< manifest[CompositeRecord[Any,Any]] => "struct " + structName(m) // + "*" + case s if s <:< manifest[Record] => "struct " + structName(m) // + "*" + case s if s.toString.contains("Pointer") => // TODO find a better place + remap(m.typeArguments.head) + "*" + case _ => super.remap(m) + } + + override def emitDataStructures(stream: PrintWriter) { + // Forward references to resolve dependencies + val hs = new scala.collection.mutable.LinkedHashMap[String,Seq[(String, Manifest[_])]] + def hit(name: String, xs: Seq[(String,Manifest[_])]): Unit = { + xs foreach { x => + val name = structName(x._2) + encounteredStructs.get(name).map(x => hit(name, x)) + } + hs(name) = xs + } + encounteredStructs.foreach((hit _).tupled) + + for ((name, elems) <- hs) { + stream.println() + stream.println("struct " + name + " {") + for(e <- elems) stream.println(remap(e._2) + " " + e._1 + ";") + stream.println("};") + } + stream.flush() + super.emitDataStructures(stream) + } +} +/*trait CudaGenStruct extends CudaGenBase with BaseGenStruct trait OpenCLGenStruct extends OpenCLGenBase with BaseGenStruct trait CudaGenFatStruct extends CudaGenStruct with BaseGenFatStruct trait OpenCLGenFatStruct extends OpenCLGenStruct with BaseGenFatStruct trait CGenFatStruct extends CGenStruct with BaseGenFatStruct +*/ diff --git a/src/common/Variables.scala b/src/common/Variables.scala index 98ca5020..8244ce5a 100755 --- a/src/common/Variables.scala +++ b/src/common/Variables.scala @@ -3,6 +3,7 @@ package common import java.io.PrintWriter import scala.reflect.SourceContext +import scala.collection.mutable import scala.lms.util.OverloadHack import scala.reflect.SourceContext @@ -53,7 +54,13 @@ trait Variables extends Base with OverloadHack with VariableImplicits with ReadV def var_plusequals[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[Unit] def var_minusequals[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[Unit] def var_timesequals[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[Unit] + def var_divideequals[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[Unit] + def var_tripleshift[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T] + def var_doubleshift[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T] + def var_leftshift[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T] + def var_logicalOr[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T] + def var_logicalAnd[T:Manifest](lhs: Var[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T] def __assign[T:Manifest](lhs: Var[T], rhs: T)(implicit pos: SourceContext) = var_assign(lhs, unit(rhs)) def __assign[T](lhs: Var[T], rhs: Rep[T])(implicit o: Overloaded1, mT: Manifest[T], pos: SourceContext) = var_assign(lhs, rhs) @@ -76,6 +83,12 @@ trait Variables extends Base with OverloadHack with VariableImplicits with ReadV def infix_/=[T](lhs: Var[T], rhs: T)(implicit o: Overloaded1, mT: Manifest[T], pos: SourceContext) = var_divideequals(lhs, unit(rhs)) def infix_/=[T](lhs: Var[T], rhs: Rep[T])(implicit o: Overloaded2, mT: Manifest[T], pos: SourceContext) = var_divideequals(lhs,rhs) def infix_/=[T](lhs: Var[T], rhs: Var[T])(implicit o: Overloaded3, mT: Manifest[T], pos: SourceContext) = var_divideequals(lhs,readVar(rhs)) + def infix_>>>[T](lhs: Var[T], rhs: Rep[T])(implicit o: Overloaded3, mT: Manifest[T], pos: SourceContext) = var_tripleshift(lhs,rhs) + def infix_>>[T](lhs: Var[T], rhs: Rep[T])(implicit o: Overloaded3, mT: Manifest[T], pos: SourceContext) = var_doubleshift(lhs,rhs) + def infix_>>[T](lhs: Var[T], rhs: Var[T])(implicit o: Overloaded4, mT: Manifest[T], pos: SourceContext) = var_doubleshift(lhs,readVar(rhs)) + def infix_<<[T](lhs: Var[T], rhs: Rep[T])(implicit o: Overloaded3, mT: Manifest[T], pos: SourceContext) = var_leftshift(lhs,rhs) + def infix_|[T](lhs: Var[T], rhs: Var[T])(implicit o: Overloaded3, mT: Manifest[T], pos: SourceContext) = var_logicalOr(lhs,readVar(rhs)) + def infix_&[T](lhs: Var[T], rhs: Var[T])(implicit o: Overloaded3, mT: Manifest[T], pos: SourceContext) = var_logicalAnd(lhs,readVar(rhs)) } trait VariablesExp extends Variables with ImplicitOpsExp with VariableImplicits with ReadVarImplicitExp { @@ -95,6 +108,11 @@ trait VariablesExp extends Variables with ImplicitOpsExp with VariableImplicits case class VarMinusEquals[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[Unit] case class VarTimesEquals[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[Unit] case class VarDivideEquals[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[Unit] + case class VarDoubleShift[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[T] + case class VarTripleShift[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[T] + case class VarLeftShift[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[T] + case class VarLogicalOr[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[T] + case class VarLogicalAnd[T:Manifest](lhs: Var[T], rhs: Exp[T]) extends Def[T] def var_new[T:Manifest](init: Exp[T])(implicit pos: SourceContext): Var[T] = { //reflectEffect(NewVar(init)).asInstanceOf[Var[T]] @@ -125,6 +143,22 @@ trait VariablesExp extends Variables with ImplicitOpsExp with VariableImplicits reflectWrite(lhs.e)(VarDivideEquals(lhs, rhs)) Const() } + + def var_tripleshift[T:Manifest](lhs: Var[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = { + reflectEffect(VarTripleShift(lhs,rhs)) + } + def var_doubleshift[T:Manifest](lhs: Var[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = { + reflectEffect(VarDoubleShift(lhs,rhs)) + } + def var_leftshift[T:Manifest](lhs: Var[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = { + reflectEffect(VarLeftShift(lhs,rhs)) + } + def var_logicalOr[T:Manifest](lhs: Var[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = { + reflectEffect(VarLogicalOr(lhs,rhs)) + } + def var_logicalAnd[T:Manifest](lhs: Var[T], rhs: Exp[T])(implicit pos: SourceContext): Exp[T] = { + reflectEffect(VarLogicalAnd(lhs,rhs)) + } override def aliasSyms(e: Any): List[Sym[Any]] = e match { case NewVar(a) => Nil @@ -170,17 +204,38 @@ trait VariablesExp extends Variables with ImplicitOpsExp with VariableImplicits case _ => super.copySyms(e) } - + def findInitSymbol(s: Exp[_]): Exp[_] = { + println(s) + findDefinition(s.asInstanceOf[Sym[_]]).get match { + case TP(_, Reflect(v @ ReadVar(Variable(x)),_,_)) => findInitSymbol(x) + case TP(_, Reflect(NewVar(x),_,_)) => { + if (x.tp != manifest[Nothing]) findInitSymbol(x) + else { + var sym : Option[Exp[_]] = None + globalDefs.find { x => x match { + case TP(_,Reflect(Assign(Variable(v1),v2),_,_)) => { + if (v1 == s) {sym = Some(v2); true} + else false; + } + case _ => false + } } + if (sym != None) findInitSymbol(sym.get) + else throw new RuntimeException("findInitSymbol failed (1) during lookup in DynamicRecords while looking for " + sym + ".") + } + } + case TP(sym, _) => sym + case sy@_ => throw new RuntimeException("findInitSymbol failed (2) during lookup in DynamicRecords while looking for " + sy + ".") + } + } override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { - case ReadVar(Variable(a)) => readVar(Variable(f(a))) - case Reflect(NewVar(a), u, es) => reflectMirrored(Reflect(NewVar(f(a)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(ReadVar(Variable(a)), u, es) => reflectMirrored(Reflect(ReadVar(Variable(f(a))), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(Assign(Variable(a),b), u, es) => reflectMirrored(Reflect(Assign(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(VarPlusEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarPlusEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(VarMinusEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarMinusEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(VarTimesEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarTimesEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) - case Reflect(VarDivideEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarDivideEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A]), pos) + case Reflect(NewVar(a), u, es) => reflectMirrored(Reflect(NewVar(f(a)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(ReadVar(Variable(a)), u, es) => reflectMirrored(Reflect(ReadVar(Variable(f(a))), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(Assign(Variable(a),b), u, es) => reflectMirrored(Reflect(Assign(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(VarPlusEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarPlusEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(VarMinusEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarMinusEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(VarTimesEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarTimesEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) + case Reflect(VarDivideEquals(Variable(a),b), u, es) => reflectMirrored(Reflect(VarDivideEquals(Variable(f(a)), f(b)), mapOver(f,u), f(es)))(mtype(manifest[A])) case _ => super.mirror(e,f) }).asInstanceOf[Exp[A]] @@ -207,27 +262,8 @@ trait VariablesExpOpt extends VariablesExp { } } - // eliminate (some) redundant stores - // TODO: strong updates. overwriting a var makes previous stores unnecessary - - override implicit def var_assign[T:Manifest](v: Var[T], e: Exp[T])(implicit pos: SourceContext) : Exp[Unit] = { - if (context ne null) { - // find the last modification of variable v - // if it is an assigment with the same value, we don't need to do anything - val vs = v.e.asInstanceOf[Sym[Variable[T]]] - //TODO: could use calculateDependencies(Read(v)) - - context.reverse.foreach { - case w @ Def(Reflect(NewVar(rhs: Exp[T]), _, _)) if w == vs => if (rhs == e) return () - case Def(Reflect(Assign(`v`, rhs: Exp[T]), _, _)) => if (rhs == e) return () - case Def(Reflect(_, u, _)) if mayWrite(u, List(vs)) => // not a simple assignment - case _ => // ... - } - } - super.var_assign(v,e) - } - - + // TODO: could eliminate redundant stores, too + // by overriding assign ... } @@ -237,27 +273,40 @@ trait ScalaGenVariables extends ScalaGenEffect { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case ReadVar(Variable(a)) => emitValDef(sym, quote(a)) - case NewVar(init) => emitVarDef(sym.asInstanceOf[Sym[Variable[Any]]], quote(init)) - case Assign(Variable(a), b) => emitAssignment(a.asInstanceOf[Sym[Variable[Any]]],quote(b)) - case VarPlusEquals(Variable(a), b) => emitValDef(sym, quote(a) + " += " + quote(b)) - case VarMinusEquals(Variable(a), b) => emitValDef(sym, quote(a) + " -= " + quote(b)) - case VarTimesEquals(Variable(a), b) => emitValDef(sym, quote(a) + " *= " + quote(b)) - case VarDivideEquals(Variable(a), b) => emitValDef(sym, quote(a) + " /= " + quote(b)) - case _ => super.emitNode(sym, rhs) - } - -/* - override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { - case ReadVar(Variable(a)) => emitValDef(sym, quote(a)) - case NewVar(init) => emitVarDef(sym.asInstanceOf[Sym[Variable[Any]]], quote(init)) - case Assign(Variable(a), b) => emitValDef(sym, quote(a) + " = " + quote(b)) - case VarPlusEquals(Variable(a), b) => emitValDef(sym, quote(a) + " += " + quote(b)) - case VarMinusEquals(Variable(a), b) => emitValDef(sym, quote(a) + " -= " + quote(b)) - case VarTimesEquals(Variable(a), b) => emitValDef(sym, quote(a) + " *= " + quote(b)) - case VarDivideEquals(Variable(a), b) => emitValDef(sym, quote(a) + " /= " + quote(b)) + case y@NewVar(init) => { + if (sym.emitted == false && sym.tp != manifest[Variable[Nothing]] && sym.tp != manifest[Variable[Null]]) { + val obj = sym.asInstanceOf[Sym[Variable[Any]]] + emitVarDef(obj, quote(init)) + sym.emitted = true; + } + } + case ReadVar(null) => {} // emitVarDef(sym.asInstanceOf[Sym[Variable[Any]]], "null") + case Assign(v @ Variable(a), b) => { + val lhsIsNull = a match { + case Def(Reflect(NewVar(y: Exp[_]),_,_)) => + if (y.tp == manifest[Nothing]) true + else false + case y@_ => false + } + val obj = a.asInstanceOf[Sym[Variable[Any]]] + if (!obj.emitted) { + stream.println("var " + quote(obj) + ": " + remap(b.tp) + " = " + quote(b)) + obj.emitted = true + } + else emitAssignment(sym, quote(a), quote(b)) + } + //case Assign(a, b) => emitAssignment(quote(a), quote(b)) + case VarPlusEquals(Variable(a), b) => stream.println(quote(a) + " += " + quote(b)) + case VarMinusEquals(Variable(a), b) => stream.println(quote(a) + " -= " + quote(b)) + case VarTimesEquals(Variable(a), b) => stream.println(quote(a) + " *= " + quote(b)) + case VarDivideEquals(Variable(a), b) => stream.println(quote(a) + " /= " + quote(b)) + case VarTripleShift(Variable(a),b) => emitValDef(sym,quote(a) + ">>>" + quote(b)) + case VarDoubleShift(Variable(a),b) => emitValDef(sym,quote(a) + ">>" + quote(b)) + case VarLeftShift(Variable(a),b) => emitValDef(sym,quote(a) + "<<" + quote(b)) + case VarLogicalOr(Variable(a),b) => emitValDef(sym,quote(a) + "|" + quote(b)) + case VarLogicalAnd(Variable(a),b) => emitValDef(sym,quote(a) + "&" + quote(b)) case _ => super.emitNode(sym, rhs) } -*/ } trait CLikeGenVariables extends CLikeGenBase { @@ -266,13 +315,21 @@ trait CLikeGenVariables extends CLikeGenBase { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { rhs match { - case ReadVar(Variable(a)) => emitValDef(sym, quote(a)) - case NewVar(init) => emitVarDef(sym.asInstanceOf[Sym[Variable[Any]]], quote(init)) - case Assign(Variable(a), b) => stream.println(quote(a) + " = " + quote(b) + ";") - case VarPlusEquals(Variable(a), b) => stream.println(quote(a) + " += " + quote(b) + ";") - case VarMinusEquals(Variable(a), b) =>stream.println(quote(a) + " -= " + quote(b) + ";") - case VarTimesEquals(Variable(a), b) => stream.println(quote(a) + " *= " + quote(b) + ";") - case VarDivideEquals(Variable(a), b) => stream.println(quote(a) + " /= " + quote(b) + ";") + case ReadVar(Variable(a)) => + stream.println(remap(sym.tp) + " " + quote(sym) + " = " + quote(a) + ";") + case ReadVar(null) => {} // emitVarDef(sym.asInstanceOf[Sym[Variable[Any]]], "null") + case NewVar(init) => + emitVarDef(sym.asInstanceOf[Sym[Variable[Any]]], quote(init)) + case Assign(Variable(a), b) => + emitAssignment(sym, quote(a), quote(b)) + case VarPlusEquals(Variable(a), b) => + emitAssignment(sym, quote(a), quote(a) + " + " + quote(b)) + case VarMinusEquals(Variable(a), b) => + emitAssignment(sym, quote(a), quote(a) + " - " + quote(b)) + case VarTimesEquals(Variable(a), b) => + emitAssignment(sym, quote(a), quote(a) + " * " + quote(b)) + case VarDivideEquals(Variable(a), b) => + emitAssignment(sym, quote(a), quote(a) + " / " + quote(b)) case _ => super.emitNode(sym, rhs) } } diff --git a/src/common/While.scala b/src/common/While.scala index daf0c82b..af4e8488 100644 --- a/src/common/While.scala +++ b/src/common/While.scala @@ -3,15 +3,18 @@ package common import java.io.PrintWriter import scala.lms.internal.GenericNestedCodegen +import scala.lms.internal.CNestedCodegen import scala.reflect.SourceContext trait While extends Base { def __whileDo(cond: => Rep[Boolean], body: => Rep[Unit])(implicit pos: SourceContext): Rep[Unit] + def __doWhile(body: => Rep[Unit], cond: => Rep[Boolean])(implicit pos: SourceContext): Rep[Unit] } trait WhileExp extends While with EffectExp { case class While(cond: Block[Boolean], body: Block[Unit]) extends Def[Unit] + case class DoWhile(body: Block[Unit], cond: Block[Boolean]) extends Def[Unit] override def __whileDo(cond: => Exp[Boolean], body: => Rep[Unit])(implicit pos: SourceContext) = { val c = reifyEffects(cond) @@ -21,26 +24,52 @@ trait WhileExp extends While with EffectExp { reflectEffect(While(c, a), ce andThen ((ae andThen ce).star)) } + override def __doWhile(body: => Rep[Unit], cond: => Rep[Boolean])(implicit pos: SourceContext) = { + val a = reifyEffects(body) + val c = reifyEffects(cond) + val ae = summarizeEffects(a) + val ce = summarizeEffects(c) + reflectEffect(DoWhile(a, c), ae andThen ((ce andThen ae).star)) + } + override def syms(e: Any): List[Sym[Any]] = e match { case While(c, b) => syms(c):::syms(b) // wouldn't need to override... + case DoWhile(b, c) => syms(b):::syms(c) // wouldn't need to override... case _ => super.syms(e) } override def boundSyms(e: Any): List[Sym[Any]] = e match { case While(c, b) => effectSyms(c):::effectSyms(b) + case DoWhile(b, c) => effectSyms(b):::effectSyms(c) case _ => super.boundSyms(e) } override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { case While(c, b) => freqHot(c):::freqHot(b) + case DoWhile(b, c) => freqHot(b):::freqHot(c) case _ => super.symsFreq(e) } } +trait WhileExpOpt extends WhileExp { this: IfThenElseExp => -trait WhileExpOptSpeculative extends WhileExp with PreviousIterationDummyExp { + /** Optimization technique(s): + * - inversion : This technique changes a standard while loop into a do/while (a.k.a. repeat/until) + * loop wrapped in an if conditional, reducing the number of jumps by two for cases + * where the loop is executed. Doing so duplicates the condition check (increasing the + * size of the code) but is more efficient because jumps usually cause a pipeline stall. + * Additionally, if the initial condition is known at compile-time and is known to be + * side-effect-free, the if guard can be skipped. + */ + override def __whileDo(cond: => Exp[Boolean], body: => Rep[Unit])(implicit pos: SourceContext) = { + __ifThenElse(cond, __doWhile(body, cond), ()) + } + +} + +trait WhileExpOptSpeculative extends WhileExpOpt with PreviousIterationDummyExp { this: IfThenElseExp => override def __whileDo(cond: => Exp[Boolean], body: => Rep[Unit])(implicit pos: SourceContext) = { @@ -91,14 +120,42 @@ trait ScalaGenWhile extends ScalaGenEffect with BaseGenWhile { override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { case While(c,b) => - stream.print("val " + quote(sym) + " = while ({") +//<<<<<<< HEAD + emitValDef(sym, "while ({") emitBlock(c) - stream.print(quote(getBlockResult(c))) + stream.println(quote(getBlockResult(c))) stream.println("}) {") emitBlock(b) stream.println(quote(getBlockResult(b))) stream.println("}") - +/*======= + //while-do's output is unit, so why do we need to assign its result to a val + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + stream.print("while ({") + emitBlock(c) + stream.print(quote(getBlockResult(c))) + stream.println("}) {") + emitBlock(b) + stream.println(quote(getBlockResult(b))) + stream.print("}") + } + emitValDef(sym, strWriter.toString)*/ + case DoWhile(b,c) => + //do-while's output is unit, so why do we need to assign its result to a val + val strWriter = new java.io.StringWriter + val localStream = new PrintWriter(strWriter); + withStream(localStream) { + stream.print("do {") + emitBlock(b) + stream.println(quote(getBlockResult(b))) + stream.println("} while ({") + emitBlock(c) + stream.print(quote(getBlockResult(c))) + stream.print("})") + } + emitValDef(sym, strWriter.toString) case _ => super.emitNode(sym, rhs) } } @@ -119,6 +176,15 @@ trait CLikeGenWhile extends CLikeGenBase with BaseGenWhile { stream.println("if (!"+quote(getBlockResult(c))+") break;") emitBlock(b) stream.println("}") + case DoWhile(b, c) => + stream.println("{") + emitBlock(b) + stream.println("}") + stream.println("for (;;) {") + emitBlock(c) + stream.println("if (!"+quote(getBlockResult(c))+") break;") + emitBlock(b) + stream.println("}") case _ => super.emitNode(sym, rhs) } } diff --git a/src/internal/CCodegen.scala b/src/internal/CCodegen.scala index e4811664..25fa2c59 100644 --- a/src/internal/CCodegen.scala +++ b/src/internal/CCodegen.scala @@ -1,223 +1,443 @@ package scala.lms package internal -import java.io.{FileWriter, PrintWriter, File} +import scala.lms.common.{Base,BaseExp} +import java.io.{FileWriter, StringWriter, PrintWriter, File} +import java.util.ArrayList +import collection.mutable.{ListBuffer, ArrayBuffer, LinkedList, HashMap, ListMap, HashSet, Map => MMap} import collection.immutable.List._ -import collection.mutable.ArrayBuffer -trait CCodegen extends CLikeCodegen with CppHostTransfer { +trait CCodegen extends CLikeCodegen { val IR: Expressions import IR._ - override def deviceTarget: Targets.Value = Targets.Cpp - override def kernelFileExt = "cpp" - override def toString = "cpp" - - val helperFuncList = ArrayBuffer[String]() - - var kernelInputVals: List[Sym[Any]] = Nil - var kernelInputVars: List[Sym[Any]] = Nil - var kernelOutputs: List[Sym[Any]] = Nil - - override def remap[A](m: Manifest[A]) : String = { - m.toString match { - case "java.lang.String" => "string" - case _ => super.remap(m) - } + override def toString = "c" + + var compileCount = 0 + var helperFuncIdx = 0 + var helperFuncString:StringBuilder = null + var hstream: PrintWriter = null + var headerStream: PrintWriter = null + var kernelsList = ListBuffer[Exp[Any]]() + + override def hasMetaData: Boolean = true + override def getMetaData: String = metaData.toString + var metaData: CMetaData = null + + final class TransferFunc { + var funcHtoD:String = _ + var argsFuncHtoD:List[Sym[Any]] = _ + var funcDtoH:String = _ + var argsFuncDtoH:List[Sym[Any]] = _ } - - // we treat string as a primitive type to prevent memory management on strings - // strings are always stack allocated and freed automatically at the scope exit - override def isPrimitiveType(tpe: String) : Boolean = { - tpe match { - case "string" => true - case _ => super.isPrimitiveType(tpe) + final class CMetaData { + val inputs: ListMap[Sym[Any],TransferFunc] = ListMap() + val outputs: ListMap[Sym[Any],TransferFunc] = ListMap() + //val temps: ListMap[Sym[Any],TransferFunc] = ListMap() + //val sizeFuncs: ListMap[String,SizeFunc] = ListMap() + //var gpuLibCall: String = "" + override def toString: String = { + val out = new StringBuilder + out.append("{") + + out.append("\"cppInputs\":["+inputs.toList.reverse.map(in=>"{\""+quote(in._1)+"\":[\""+remap(in._1.tp)+"\",\""+in._2.funcHtoD+"\",\""+in._2.funcDtoH+"\"]}").mkString(",")+"],") + out.append("\"cppOutputs\":["+outputs.toList.reverse.map(out=>"{\""+quote(out._1)+"\":[\""+remap(out._1.tp)+"\",\""+out._2.funcDtoH+"\"]}").mkString(",")+"]") + out.append("}") + out.toString } } - - override def quote(x: Exp[Any]) = x match { - case Const(s: String) => "string(" + super.quote(x) + ")" - case _ => super.quote(x) - } - override def isPrimitiveType[A](m: Manifest[A]) : Boolean = isPrimitiveType(remap(m)) - - override def emitValDef(sym: Sym[Any], rhs: String): Unit = { - if (!isVoidType(sym.tp)) - stream.println(remapWithRef(sym.tp) + quote(sym) + " = " + rhs + ";") - else // we might still want the RHS for its effects - stream.println(rhs + ";") + def initCompile = { + val className = "staged" + compileCount + compileCount = compileCount + 1 + className } - override def emitVarDef(sym: Sym[Variable[Any]], rhs: String): Unit = { - stream.println(remapWithRef(sym.tp.typeArguments.head) + quote(sym) + " = " + rhs + ";") + override def kernelInit(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultIsVar: Boolean): Unit = { + // Set kernel input and output symbols + //setKernelInputs(vals) + //setKernelOutputs(syms) + + /* + // Conditions for not generating GPU kernels (may be relaxed later) + for (sym <- syms) { + if((!isObjectType(sym.tp)) && (remap(sym.tp)!="void")) throw new GenerationFailedException("GPUGen: Not GPUable output type : %s".format(remap(sym.tp))) + } + if((vars.length > 0) || (resultIsVar)) throw new GenerationFailedException("GPUGen: Not GPUable input/output types: Variable") +*/ + helperFuncString.clear + metaData = new CMetaData } - override def emitVarDecl(sym: Sym[Any]): Unit = { - stream.println(remapWithRef(sym.tp) + " " + quote(sym) + ";") + override def initializeGenerator(buildDir:String, args: Array[String], _analysisResults: MMap[String,Any]): Unit = { + val outDir = new File(buildDir) + outDir.mkdirs + helperFuncIdx = 0 + helperFuncString = new StringBuilder + hstream = new PrintWriter(new FileWriter(buildDir + "helperFuncs.cpp")) + headerStream = new PrintWriter(new FileWriter(buildDir + "dsl.hpp")) + headerStream.println("#include \"helperFuncs.cpp\"") + + /* + //TODO: Put all the DELITE APIs declarations somewhere + hstream.print(getDSLHeaders) + hstream.print("#include \n") + hstream.print("#include \n") + hstream.print("#include \n\n") + hstream.print("//Delite Runtime APIs\n") + hstream.print("extern void DeliteCudaMallocHost(void **ptr, size_t size);\n") + hstream.print("extern void DeliteCudaMalloc(void **ptr, size_t size);\n") + hstream.print("extern void DeliteCudaMemcpyHtoDAsync(void *dptr, void *sptr, size_t size);\n") + hstream.print("extern void DeliteCudaMemcpyDtoHAsync(void *dptr, void *sptr, size_t size);\n") + hstream.print("typedef jboolean jbool;\n") // TODO: Fix this + hstream.print("typedef jbooleanArray jboolArray;\n\n") // TODO: Fix this + */ + + super.initializeGenerator(buildDir, args, _analysisResults) } - override def emitAssignment(sym: Sym[Any], rhs: String): Unit = { - stream.println(quote(sym) + " = " + rhs + ";") + def copyInputHtoD(sym: Sym[Any]) : String = { + remap(sym.tp) match { + case _ => throw new GenerationFailedException("CGen: copyInputHtoD(sym) : Cannot copy to GPU device (%s)".format(remap(sym.tp))) + } } - override def kernelInit(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultIsVar: Boolean): Unit = { - kernelInputVals = vals - kernelInputVars = vars - kernelOutputs = syms + def copyOutputDtoH(sym: Sym[Any]) : String = { + remap(sym.tp) match { + case _ => throw new GenerationFailedException("CGen: copyOutputDtoH(sym) : Cannot copy from GPU device (%s)".format(remap(sym.tp))) + } } - override def initializeGenerator(buildDir:String, args: Array[String]): Unit = { - val outDir = new File(buildDir) - outDir.mkdirs - - /* file for helper functions (transfer function, allocation function) */ - helperFuncStream = new PrintWriter(new FileWriter(buildDir + deviceTarget + "helperFuncs.cpp")) - helperFuncStream.println("#include ") - helperFuncStream.println("#include \"" + deviceTarget + "helperFuncs.h\"") - - /* type aliases */ - typesStream = new PrintWriter(new FileWriter(buildDir + deviceTarget + "types.h")) - - /* header file for kernels and helper functions */ - headerStream = new PrintWriter(new FileWriter(buildDir + deviceTarget + "helperFuncs.h")) - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include ") - headerStream.println("#include \"" + deviceTarget + "types.h\"") - headerStream.println(getDataStructureHeaders()) - - super.initializeGenerator(buildDir, args) + def copyMutableInputDtoH(sym: Sym[Any]) : String = { + remap(sym.tp) match { + case _ => throw new GenerationFailedException("CGen: copyMutableInputDtoH(sym) : Cannot copy from GPU device (%s)".format(remap(sym.tp))) + } } def emitForwardDef[A:Manifest](args: List[Manifest[_]], functionName: String, out: PrintWriter) = { out.println(remap(manifest[A])+" "+functionName+"("+args.map(a => remap(a)).mkString(", ")+");") } - - def emitSource[A:Manifest](args: List[Sym[_]], body: Block[A], functionName: String, out: PrintWriter) = { - val sA = remap(manifest[A]) + def allocStruct(sym: Sym[Any], structName: String, out: PrintWriter) { + out.println(structName + " " + quote(sym)+ ";") + //out.println(structName + "* " + quote(sym) + " = (" + structName + "*)malloc(sizeof(" + structName + "));") + } + + def getMemoryAllocString(count: String, memType: String): String = { + "(" + memType + "*)malloc(" + count + " * sizeof(" + memType + "));" + } + def emitSource[A:Manifest](args: List[Sym[_]], b: Block[A], functionName: String, out: PrintWriter, dynamicReturnType: String = null, serializable: Boolean = false) = { + val body = runTransformations(b) + val sA = if (dynamicReturnType != null) dynamicReturnType else remap(getBlockResult(body).tp) withStream(out) { stream.println("/*****************************************\n"+ " Emitting C Generated Code \n"+ "*******************************************/\n" + "#include \n" + "#include \n" + - "#include \n" + - "#include " - ) + "#include \n" + + "#include ") + stream.println("int tpch_strcmp(const char *s1, const char *s2);") - // TODO: static data - - //stream.println("class "+className+(if (staticData.isEmpty) "" else "("+staticData.map(p=>"p"+quote(p._1)+":"+p._1.tp).mkString(",")+")")+" - //extends (("+args.map(a => remap(a.tp)).mkString(", ")+")=>("+sA+")) {") + stream.println("int timeval_subtract(struct timeval *result, struct timeval *t2, struct timeval *t1) {\n" + + "\tlong int diff = (t2->tv_usec + 1000000 * t2->tv_sec) - (t1->tv_usec + 1000000 * t1->tv_sec);\n" + + "\tresult->tv_sec = diff / 1000000;\n" + + "\tresult->tv_usec = diff % 1000000;\n" + + "\treturn (diff<0);\n" + + "}\n") - stream.println(sA+" "+functionName+"("+args.map(a => remapWithRef(a.tp)+" "+quote(a)).mkString(", ")+") {") - - emitBlock(body) - - val y = getBlockResult(body) + // TODO: static data + val sw = new StringWriter() + val tempWriter = new PrintWriter(sw) + tempWriter.println(sA+" "+functionName+"("+args.map(a => remap(a.tp)+" "+quote(a)).mkString(", ")+") {") + withStream(tempWriter) { emitBlock(body) } + val y = getBlockResult(body) if (remap(y.tp) != "void") - stream.println("return " + quote(y) + ";") - - stream.println("}") + tempWriter.println("return " + quote(y) + ";") + tempWriter.println("}") + + var code = sw.toString + sw.getBuffer().setLength(0) + withStream(tempWriter) { emitFileHeader() } + code = sw.toString + code + + stream.println("/********************* DATA STRUCTURES ***********************/") + emitDataStructures(stream) + stream.println("") + stream.println("/************************ FUNCTIONS **************************/") + sw.getBuffer().setLength(0) + withStream(tempWriter) { emitFunctions() } + val funs = sw.toString + printIndented(funs)(stream) + stream.println("") + stream.println("/************************ MAIN BODY **************************/") + //stream.println(code) + printIndented(code)(stream) stream.println("/*****************************************\n"+ - " End of C Generated Code \n"+ - "*******************************************/") + " * End of C Generated Code *\n"+ + " *****************************************/") } Nil - } - - override def emitTransferFunctions() { - - for ((tp,name) <- dsTypesList) { - try { - // Emit input copy helper functions for object type inputs - //TODO: For now just iterate over all possible hosts, but later we can pick one depending on the input target - val (recvHeader, recvSource) = emitRecv(tp, Targets.JVM) - if (!helperFuncList.contains(recvHeader)) { - headerStream.println(recvHeader) - helperFuncStream.println(recvSource) - helperFuncList.append(recvHeader) - } - val (recvViewHeader, recvViewSource) = emitRecvView(tp, Targets.JVM) - if (!helperFuncList.contains(recvViewHeader)) { - headerStream.println(recvViewHeader) - helperFuncStream.println(recvViewSource) - helperFuncList.append(recvViewHeader) - } - val (sendUpdateHeader, sendUpdateSource) = emitSendUpdate(tp, Targets.JVM) - if (!helperFuncList.contains(sendUpdateHeader)) { - headerStream.println(sendUpdateHeader) - helperFuncStream.println(sendUpdateSource) - helperFuncList.append(sendUpdateHeader) - } - val (recvUpdateHeader, recvUpdateSource) = emitRecvUpdate(tp, Targets.JVM) - if (!helperFuncList.contains(recvUpdateHeader)) { - headerStream.println(recvUpdateHeader) - helperFuncStream.println(recvUpdateSource) - helperFuncList.append(recvUpdateHeader) - } + } - // Emit output copy helper functions for object type inputs - val (sendHeader, sendSource) = emitSend(tp, Targets.JVM) - if (!helperFuncList.contains(sendHeader)) { - headerStream.println(sendHeader) - helperFuncStream.println(sendSource) - helperFuncList.append(sendHeader) + def printIndented(str: String)(out: PrintWriter): Unit = { + val lines = str.split("[\n\r]") + var indent = 0 + for (l0 <- lines) { + val l = l0.trim + if (l.length > 0) { + var open = 0 + var close = 0 + var initClose = 0 + var nonWsChar = false + l foreach { + case '{' /*| '(' | '['*/ => { + open += 1 + if (!nonWsChar) { + nonWsChar = true + initClose = close + } + } + case '}' /*| ')' | ']'*/ => close += 1 + case x => if (!nonWsChar && !x.isWhitespace) { + nonWsChar = true + initClose = close + } } - val (sendViewHeader, sendViewSource) = emitSendView(tp, Targets.JVM) - if (!helperFuncList.contains(sendViewHeader)) { - headerStream.println(sendViewHeader) - helperFuncStream.println(sendViewSource) - helperFuncList.append(sendViewHeader) - } - } - catch { - case e: GenerationFailedException => - helperFuncStream.flush - headerStream.flush - case e: Exception => throw(e) + if (!nonWsChar) initClose = close + out.println(" " * (indent - initClose) + l) + indent += (open - close) } } + } - helperFuncStream.flush - headerStream.flush - typesStream.flush + + +/* + //TODO: is sym of type Any or Variable[Any] ? + def emitConstDef(sym: Sym[Any], rhs: String): Unit = { + stream.print("const ") + emitVarDef(sym, rhs) + } +*/ + def emitVarDef(sym: Sym[Variable[Any]], rhs: String): Unit = { + // TODO: check void type? + stream.println(remap(sym.tp) + " " + quote(sym) + " = " + rhs + ";") } - def kernelName = "kernel_" + kernelOutputs.map(quote).mkString("") + def emitValDef(sym: Sym[Any], rhs: String): Unit = { + if (remap(sym.tp) == "void") + stream.println(rhs + "; // " + quote(sym, true)) + else + stream.println(remap(sym.tp) + " " + quote(sym) + " = " + rhs + ";") + } - override def emitKernelHeader(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { - super.emitKernelHeader(syms, vals, vars, resultType, resultIsVar, external) + def emitAssignment(sym: Sym[Any], lhs:String, rhs: String): Unit = { + // TODO: check void type? + stream.println(lhs + " = " + rhs + ";") } override def emitKernelFooter(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { - super.emitKernelFooter(syms, vals, vars, resultType, resultIsVar, external) + + //Currently only allow single return value + if(syms.size > 1) throw new GenerationFailedException("CLikeGen: Cannot have more than 1 results!\n"); + if(external) throw new GenerationFailedException("CLikeGen: Cannot have external libraries\n") + + if(resultType != "void") + stream.println("return " + quote(syms(0)) + ";") + stream.println("}") + + // Emit input copy helper functions for object type inputs + for(v <- (vals++vars) if isObjectType(v.tp)) { + helperFuncString.append(emitCopyInputHtoD(v, syms, copyInputHtoD(v))) + helperFuncString.append(emitCopyMutableInputDtoH(v, syms, copyMutableInputDtoH(v))) + } + + // Emit output copy helper functions for object type inputs + for(v <- (syms) if isObjectType(v.tp)) { + helperFuncString.append(emitCopyOutputDtoH(v, syms, copyOutputDtoH(v))) + } + + // Print helper functions to file stream + hstream.print(helperFuncString) + hstream.flush + + // Print out dsl.h file + if(kernelsList.intersect(syms).isEmpty) { + headerStream.println("#include \"%s.cpp\"".format(syms.map(quote).mkString(""))) + kernelsList ++= syms + } + headerStream.flush + + /* + // Print out device function + devStream.println(devFuncString) + devStream.flush + */ + } + + override def emitKernelHeader(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { + if(syms.size>1) throw new GenerationFailedException("CGen: Cannot have multiple kernel outputs!\n") + + //if( (vars.length>0) || (resultIsVar) ) throw new GenerationFailedException("Var is not supported for CPP kernels") + + val kernelName = syms.map(quote).mkString("") + + /* + if (resultIsVar){ + stream.print("PrimitiveRef<" + resultType + ">") + } + else { + stream.print(resultType) + } + */ + stream.print(resultType) + + stream.print(" kernel_" + kernelName + "(") + stream.print(vals.map(p=>remap(p.tp) + " " + quote(p)).mkString(", ")) + if (vals.length > 0 && vars.length > 0){ + stream.print(", ") + } + if (vars.length > 0){ + stream.print(vars.map(v => remap(v.tp) + " &" + quote(v)).mkString(",")) + } + + stream.println(") {") + } + + override def quote(x: Exp[Any]) : String = { + x match { + case Const(y: java.lang.Character) => + if (y == '\0') "'\\0'" + else "'" + y.toString + "'" + case Const(null) => "NULL" + case Const(()) => ";" + case _ => super.quote(x) + } + } + + override def remap[A](m: Manifest[A]) = { + m match { + case s if m == manifest[Int] => "int" + case s if m == manifest[Double] => "double" + case s if m == manifest[Long] => "long" + case s if m == manifest[Character] => "char" + case s if m == manifest[Byte] => "char" + case s if m == manifest[Boolean] => "bool" + case s if m == manifest[String] => "char*" + case s if m == manifest[Float] => "float" + case s if m == manifest[Unit] => "void" + case s if m == manifest[java.util.Date] => "long" + case _ => super.remap(m) + } } + /******************************************************* + * Methods below are for emitting helper functions + *******************************************************/ + // Yannis: Should these things be here? They do not seem to be general C, but + // rather CUDA like programming. + def emitCopyInputHtoD(sym: Sym[Any], ksym: List[Sym[Any]], contents: String) : String = { + val out = new StringBuilder + if(isObjectType(sym.tp)) { + helperFuncIdx += 1 + out.append("%s copyInputHtoD_%s_%s_%s(%s) {\n".format(remap(sym.tp), ksym.map(quote).mkString(""), quote(sym),helperFuncIdx, "JNIEnv *env , jobject obj")) + out.append(copyInputHtoD(sym)) + out.append("}\n") + val tr = metaData.inputs.getOrElse(sym,new TransferFunc) + tr.funcHtoD = "copyInputHtoD_%s_%s_%s".format(ksym.map(quote).mkString(""),quote(sym),helperFuncIdx) + metaData.inputs.put(sym,tr) + out.toString + } + else { + val tr = metaData.inputs.getOrElse(sym,new TransferFunc) + tr.funcHtoD = "copyInputHtoD_dummy".format(ksym.map(quote).mkString(""),quote(sym),helperFuncIdx) + metaData.inputs.put(sym,tr) + "" + } + } + + // For mutable inputs, copy the mutated datastructure from GPU to CPU after the kernel is terminated + def emitCopyMutableInputDtoH(sym: Sym[Any], ksym: List[Sym[Any]], contents: String): String = { + val out = new StringBuilder + if(isObjectType(sym.tp)) { + helperFuncIdx += 1 + out.append("void copyMutableInputDtoH_%s_%s_%s(%s) {\n".format(ksym.map(quote).mkString(""), quote(sym), helperFuncIdx, "JNIEnv *env , jobject obj, "+remap(sym.tp)+" *"+quote(sym)+"_ptr")) + out.append("%s %s = *(%s_ptr);\n".format(remap(sym.tp),quote(sym),quote(sym))) + out.append(copyMutableInputDtoH(sym)) + out.append("}\n") + val tr = metaData.inputs.getOrElse(sym,new TransferFunc) + tr.funcDtoH = "copyMutableInputDtoH_%s_%s_%s".format(ksym.map(quote).mkString(""),quote(sym),helperFuncIdx) + metaData.inputs.put(sym,tr) + out.toString + } + else { + val tr = metaData.inputs.getOrElse(sym,new TransferFunc) + tr.funcDtoH = "copyMutableInputDtoH_%s_%s_%s".format(ksym.map(quote).mkString(""),quote(sym),helperFuncIdx) + metaData.inputs.put(sym,tr) + "" + } + } + + def emitCopyOutputDtoH(sym: Sym[Any], ksym: List[Sym[Any]], contents: String): String = { + val out = new StringBuilder + if(isObjectType(sym.tp)) { + helperFuncIdx += 1 + out.append("jobject copyOutputDtoH_%s(JNIEnv *env,%s) {\n".format(helperFuncIdx,remap(sym.tp)+" *"+quote(sym)+"_ptr")) + out.append("\t%s %s = *(%s_ptr);\n".format(remap(sym.tp),quote(sym),quote(sym))) + out.append(copyOutputDtoH(sym)) + out.append("}\n") + val tr = metaData.outputs.getOrElse(sym,new TransferFunc) + tr.funcDtoH = "copyOutputDtoH_%s".format(helperFuncIdx) + metaData.outputs.put(sym,tr) + out.toString + } + else { + val tr = metaData.outputs.getOrElse(sym,new TransferFunc) + tr.funcDtoH = "copyOutputDtoH_%s".format(helperFuncIdx) + metaData.outputs.put(sym,tr) + "" + } + } + + + } -trait CNestedCodegen extends CLikeNestedCodegen with CCodegen { - val IR: Expressions with Effects +// TODO: do we need this for each target? +trait CNestedCodegen extends GenericNestedCodegen with CCodegen { + val IR: Expressions with Effects with LoweringTransform import IR._ - } -trait CFatCodegen extends CLikeFatCodegen with CCodegen { - val IR: Expressions with Effects with FatExpressions - import IR._ +trait CFatCodegen extends GenericFatCodegen with CCodegen { + val IR: Expressions with Effects with FatExpressions with LoweringTransform +} +trait Pointer extends Base { + class PointerManifest[A:Manifest] + def pointer_assign[A:Manifest](s: Rep[A], vl: Rep[A]): Rep[Unit] + def getPointerManifest[A:Manifest] = manifest[PointerManifest[A]] +} + +trait PointerExp extends Pointer with BaseExp with Effects { + case class PointerAssign[A:Manifest](s: Exp[A], vl: Exp[A]) extends Def[Unit] + def pointer_assign[A:Manifest](s: Exp[A], vl: Exp[A]) = reflectEffect(PointerAssign(s,vl)) +} + +trait CGenPointer extends GenericNestedCodegen { + val IR: PointerExp + import IR._ + + override def remap[A](m: Manifest[A]) = m match { + case s if m <:< manifest[PointerManifest[Any]] => remap(m.typeArguments.head) + "*" + case _ => super.remap(m) + } + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { + case PointerAssign(s,vl) => stream.println("*" + quote(s) + " = " + quote(vl) + ";") + case _ => super.emitNode(sym, rhs) + } } diff --git a/src/internal/CLikeCodegen.scala b/src/internal/CLikeCodegen.scala index 2f8ee4da..64f5ac78 100644 --- a/src/internal/CLikeCodegen.scala +++ b/src/internal/CLikeCodegen.scala @@ -2,182 +2,82 @@ package scala.lms package internal import java.io.PrintWriter -import collection.mutable.HashSet trait CLikeCodegen extends GenericCodegen { val IR: Expressions import IR._ +/* + //TODO: is sym of type Any or Variable[Any] ? + def emitConstDef(sym: Sym[Any], rhs: String): Unit +*/ + def emitVarDef(sym: Sym[Variable[Any]], rhs: String): Unit + def emitValDef(sym: Sym[Any], rhs: String): Unit + def emitAssignment(sym: Sym[Any], lhs:String, rhs: String): Unit + + override def emitKernelHeader(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { + val List(sym) = syms // TODO - def mangledName(name: String) = name.replaceAll("\\s","").map(c => if(!c.isDigit && !c.isLetter) '_' else c) - - // List of datastructure types that requires transfer functions to be generated for this target - val dsTypesList = HashSet[(Manifest[_],String)]() - - // Streams for helper functions and its header - var helperFuncStream: PrintWriter = _ - var headerStream: PrintWriter = _ - var actRecordStream: PrintWriter = _ - var typesStream: PrintWriter = _ - - def emitVarDef(sym: Sym[Variable[Any]], rhs: String): Unit = emitValDef(sym, rhs) - - def emitValDef(sym: Sym[Any], rhs: String): Unit = emitValDef(quote(sym), sym.tp, rhs) - - def emitValDef(sym: String, tpe: Manifest[_], rhs: String): Unit = { - if(remap(tpe) != "void") stream.println(remap(tpe) + " " + sym + " = " + rhs + ";") - } - - override def emitVarDecl(sym: Sym[Any]): Unit = { - stream.println(remap(sym.tp) + " " + quote(sym) + ";") - } - - override def emitAssignment(sym: Sym[Any], rhs: String): Unit = { - stream.println(quote(sym) + " = " + rhs + ";") - } - - def remapWithRef[A](m: Manifest[A]): String = remap(m) + addRef(m) - def remapWithRef(tpe: String): String = tpe + addRef(tpe) + if( (vars.length>0) || (resultIsVar) ) throw new GenerationFailedException("Var is not supported for CPP kernels") - override def remap[A](m: Manifest[A]) : String = { - if (m.erasure == classOf[Variable[AnyVal]]) - remap(m.typeArguments.head) - else if (m.erasure == classOf[List[Any]]) { // Use case: Delite Foreach sync list - deviceTarget.toString + "List< " + remap(m.typeArguments.head) + " >" - } - else { - m.toString match { - case "scala.collection.immutable.List[Float]" => "List" - case "Boolean" => "bool" - case "Byte" => "int8_t" - case "Char" => "uint16_t" - case "Short" => "int16_t" - case "Int" => "int32_t" - case "Long" => "int64_t" - case "Float" => "float" - case "Double" => "double" - case "Unit" => "void" - case "Nothing" => "void" - case _ => throw new GenerationFailedException("CLikeGen: remap(m) : Type %s cannot be remapped.".format(m.toString)) - } - } + val paramStr = vals.map(ele=>remap(ele.tp) + " " + quote(ele)).mkString(", ") + stream.println("%s kernel_%s(%s) {".format(resultType, quote(sym), paramStr)) } - def addRef(): String = if (cppMemMgr=="refcnt") " " else " *" - def addRef[A](m: Manifest[A]): String = addRef(remap(m)) - def addRef(tpe: String): String = { - if (!isPrimitiveType(tpe) && !isVoidType(tpe)) addRef() - else " " + override def emitKernelFooter(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { + val List(sym) = syms // TODO + + if(resultType != "void") + stream.println("return " + quote(sym) + ";") + stream.println("}") } - // move to CCodegen? - def unwrapSharedPtr(tpe: String): String = { - assert(cppMemMgr == "refcnt") - if(tpe.contains("std::shared_ptr")) - tpe.replaceAll("std::shared_ptr<","").replaceAll(">","") - else - tpe - } - def wrapSharedPtr(tpe: String): String = { - assert(cppMemMgr == "refcnt") - if(!isPrimitiveType(tpe) && !isVoidType(tpe)) - "std::shared_ptr<" + tpe + ">" - else - tpe - } - - override def emitKernelHeader(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { - - stream.append("#include \"" + deviceTarget + "helperFuncs.h\"\n") - - def kernelSignature: String = { - val out = new StringBuilder - if(resultIsVar) { - if (cppMemMgr == "refcnt") - out.append(wrapSharedPtr(hostTarget + "Ref" + unwrapSharedPtr(resultType))) - else - out.append(hostTarget + "Ref" + resultType + addRef()) - } - else { - out.append(resultType + addRef(resultType)) - } - - out.append(" kernel_" + syms.map(quote).mkString("") + "(") - out.append(vals.map(p => remap(p.tp) + " " + addRef(p.tp) + quote(p)).mkString(", ")) - if (vals.length > 0 && vars.length > 0) { - out.append(", ") - } - if (vars.length > 0) { - if (cppMemMgr == "refcnt") - out.append(vars.map(v => wrapSharedPtr(hostTarget + "Ref" + unwrapSharedPtr(remap(v.tp))) + " " + quote(v)).mkString(",")) - else - out.append(vars.map(v => hostTarget + "Ref" + remap(v.tp) + addRef() + " " + quote(v)).mkString(",")) - } - out.append(")") - out.toString + def isObjectType[A](m: Manifest[A]) : Boolean = { + m.toString match { + case _ => false } + } - //TODO: Remove the dependency to Multiloop to Delite - if (!resultType.startsWith("DeliteOpMultiLoop")) { - stream.println(kernelSignature + " {") - headerStream.println(kernelSignature + ";") + def remapToJNI[A](m: Manifest[A]) : String = { + remap(m) match { + case "bool" => "Boolean" + case "char" => "Byte" + case "CHAR" => "Char" + case "short" => "Short" + case "int" => "Int" + case "long" => "Long" + case "float" => "Float" + case "double" => "Double" + case _ => throw new GenerationFailedException("GPUGen: Cannot get array creation JNI function for this type " + remap(m)) } } - override def emitKernelFooter(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { - //TODO: Remove the dependency to Multiloop to Delite - if(resultType != "void" && !resultType.startsWith("DeliteOpMultiLoop")) - stream.println("return " + quote(syms(0)) + ";") - if(!resultType.startsWith("DeliteOpMultiLoop")) - stream.println("}") -/* - for(s <- syms++vals++vars) { - if(dsTypesList.contains(s.tp)) println("contains :" + remap(s.tp)) - else println("not contains: " + remap(s.tp)) - } - println(syms.map(quote).mkString("") + "adding dsTypesList:" + (syms++vals++vars).map(_.tp).mkString(",")) - dsTypesList ++= (syms++vals++vars).map(_.tp) - println("dsTyps-lms:" + dsTypesList.map(remap(_)).mkString(",")) //toString) - */ - dsTypesList ++= (syms++vals++vars).map(s => (s.tp,remap(s.tp))) + // Map a scala primitive type to JNI type descriptor + def JNITypeDescriptor[A](m: Manifest[A]) : String = m.toString match { + case "Boolean" => "Z" + case "Byte" => "B" + case "Char" => "C" + case "Short" => "S" + case "Int" => "I" + case "Long" => "J" + case "Float" => "F" + case "Double" => "D" + case _ => throw new GenerationFailedException("Undefined GPU type") } - def isPrimitiveType(tpe: String) : Boolean = { - tpe match { - case "bool" | "int8_t" | "uint16_t" | "int16_t" | "int32_t" | "int64_t" | "float" | "double" => true - case _ => false - } - } - def isVoidType(tpe: String) : Boolean = { - if(tpe == "void") true - else false - } - - def CLikeConsts(x:Exp[Any], s:String): String = { - s match { - case "Infinity" => "std::numeric_limits<%s>::max()".format(remap(x.tp)) - case _ => super.quote(x) - } - } - - override def quote(x: Exp[Any]) = x match { - case Const(s: Unit) => "" - case Const(s: Float) => s+"f" - case Const(null) => "NULL" - case Const(z) => CLikeConsts(x, z.toString) - case Sym(-1) => "_" - case _ => super.quote(x) - } } trait CLikeNestedCodegen extends GenericNestedCodegen with CLikeCodegen { - val IR: Expressions with Effects + val IR: Expressions with Effects with LoweringTransform import IR._ } trait CLikeFatCodegen extends GenericFatCodegen with CLikeCodegen { - val IR: Expressions with Effects with FatExpressions + val IR: Expressions with Effects with FatExpressions with LoweringTransform import IR._ + + def emitMultiLoopCond(sym: Sym[Any], funcs:List[Block[Any]], idx: Sym[Int], postfix: String="", stream:PrintWriter):(String,List[Exp[Any]]) + } diff --git a/src/internal/Effects.scala b/src/internal/Effects.scala index d64617f1..ffe40f7e 100644 --- a/src/internal/Effects.scala +++ b/src/internal/Effects.scala @@ -28,8 +28,6 @@ trait Effects extends Expressions with Blocks with Utils { var context: State = _ - var conditionalScope = false // used to construct Control nodes - // --- class defs case class Reflect[+A](x:Def[A], summary: Summary, deps: List[Exp[Any]]) extends Def[A] @@ -43,27 +41,26 @@ trait Effects extends Expressions with Blocks with Utils { val mayGlobal: Boolean, val mstGlobal: Boolean, val resAlloc: Boolean, - val control: Boolean, val mayRead: List[Sym[Any]], val mstRead: List[Sym[Any]], val mayWrite: List[Sym[Any]], val mstWrite: List[Sym[Any]]) - def Pure() = new Summary(false,false,false,false,false,false,Nil,Nil,Nil,Nil) - def Simple() = new Summary(true,true,false,false,false,false,Nil,Nil,Nil,Nil) - def Global() = new Summary(false,false,true,true,false,false,Nil,Nil,Nil,Nil) - def Alloc() = new Summary(false,false,false,false,true,false,Nil,Nil,Nil,Nil) - def Control() = new Summary(false,false,false,false,false,true,Nil,Nil,Nil,Nil) + def Pure() = new Summary(false,false,false,false,false,Nil,Nil,Nil,Nil) + def Simple() = new Summary(true,true,false,false,false,Nil,Nil,Nil,Nil) + def Global() = new Summary(false,false,true,true,false,Nil,Nil,Nil,Nil) + def Alloc() = new Summary(false,false,false,false,true,Nil,Nil,Nil,Nil) - def Read(v: List[Sym[Any]]) = new Summary(false,false,false,false,false,false,v.distinct,v.distinct,Nil,Nil) - def Write(v: List[Sym[Any]]) = new Summary(false,false,false,false,false,false,Nil,Nil,v.distinct,v.distinct) + def Read(v: List[Sym[Any]]) = new Summary(false,false,false,false,false,v.distinct,v.distinct,Nil,Nil) + def Write(v: List[Sym[Any]]) = new Summary(false,false,false,false,false,Nil,Nil,v.distinct,v.distinct) def mayRead(u: Summary, a: List[Sym[Any]]): Boolean = u.mayGlobal || a.exists(u.mayRead contains _) def mayWrite(u: Summary, a: List[Sym[Any]]): Boolean = u.mayGlobal || a.exists(u.mayWrite contains _) - def maySimple(u: Summary): Boolean = u.mayGlobal || u.maySimple + def maySimple(u: Summary): Boolean = u.mayGlobal || u.maySimple def mustMutable(u: Summary): Boolean = u.resAlloc def mustPure(u: Summary): Boolean = u == Pure() + def mustOnlyAlloc(u: Summary): Boolean = u == Alloc() // only has a resource allocation def mustOnlyRead(u: Summary): Boolean = u == Pure().copy(mayRead=u.mayRead, mstRead=u.mstRead) // only reads allowed def mustIdempotent(u: Summary): Boolean = mustOnlyRead(u) // currently only reads are treated as idempotent @@ -73,7 +70,6 @@ trait Effects extends Expressions with Blocks with Utils { u.maySimple || v.maySimple, u.mstSimple && v.mstSimple, u.mayGlobal || v.mayGlobal, u.mstGlobal && v.mstGlobal, false, //u.resAlloc && v.resAlloc, <--- if/then/else will not be mutable! - u.control || v.control, (u.mayRead ++ v.mayRead).distinct, (u.mstRead intersect v.mstRead), (u.mayWrite ++ v.mayWrite).distinct, (u.mstWrite intersect v.mstWrite) ) @@ -82,30 +78,20 @@ trait Effects extends Expressions with Blocks with Utils { u.maySimple || v.maySimple, u.mstSimple || v.mstSimple, u.mayGlobal || v.mayGlobal, u.mstGlobal || v.mstGlobal, u.resAlloc || v.resAlloc, - u.control || v.control, (u.mayRead ++ v.mayRead).distinct, (u.mstRead ++ v.mstRead).distinct, (u.mayWrite ++ v.mayWrite).distinct, (u.mstWrite ++ v.mstWrite).distinct ) def infix_andThen(u: Summary, v: Summary) = new Summary( u.maySimple || v.maySimple, u.mstSimple || v.mstSimple, - u.mayGlobal || v.mayGlobal, u.mstGlobal || v.mstGlobal, + u.mayGlobal || v.mayGlobal, u.mstGlobal || v.mstGlobal, v.resAlloc, - u.control || v.control, (u.mayRead ++ v.mayRead).distinct, (u.mstRead ++ v.mstRead).distinct, (u.mayWrite ++ v.mayWrite).distinct, (u.mstWrite ++ v.mstWrite).distinct ) def infix_star(u: Summary) = Pure() orElse u // any number of repetitions, including 0 - def infix_withoutControl(u: Summary) = new Summary( - u.maySimple, u.mstSimple, - u.mayGlobal, u.mstGlobal, - u.resAlloc, - false, - u.mayRead, u.mstRead, - u.mayWrite, u.mstWrite - ) def summarizeEffects(e: Block[Any]) = e match { case Block(Def(Reify(_,u,_))) => u @@ -117,29 +103,8 @@ trait Effects extends Expressions with Blocks with Utils { // --- reflect helpers - def controlDep(x: Exp[Any]) = x match { - case Def(Reflect(y,u,es)) if u == Control() => true - case _ => false - } - - // performance hotspot - def nonControlSyms[R](es: List[Exp[Any]], ss: Any => List[R]): List[R] = { - // es.filterNot(controlDep).flatMap(syms) - val out = new mutable.ListBuffer[R] - var it = es.iterator - while (it.hasNext) { - val e = it.next() - if (!controlDep(e)) out ++= ss(e) - } - out.result - } - override def syms(e: Any): List[Sym[Any]] = e match { case s: Summary => Nil // don't count effect summaries as dependencies! - - // enable DCE of reflect nodes if they are only control dependencies - case Reflect(x,u,es) if addControlDeps => syms(x) ::: nonControlSyms(es, syms) - case Reify(x,u,es) if addControlDeps => syms(x) ::: nonControlSyms(es, syms) case _ => super.syms(e) } @@ -150,10 +115,6 @@ trait Effects extends Expressions with Blocks with Utils { override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { case s: Summary => Nil // don't count effect summaries as dependencies! - - // enable DCE of reflect nodes if they are only control dependencies - case Reflect(x,u,es) if addControlDeps => symsFreq(x) ::: nonControlSyms(es, symsFreq) - case Reify(x,u,es) if addControlDeps => symsFreq(x) ::: nonControlSyms(es, symsFreq) case _ => super.symsFreq(e) } @@ -353,7 +314,8 @@ trait Effects extends Expressions with Blocks with Utils { createReflectDefinition // if summary is not pure */ // warn if type is Any. TODO: make optional, sometimes Exp[Any] is fine - if (manifest[T] == manifest[Any]) printlog("warning: possible missing mtype call - toAtom with Def of type Any " + d) + if (Config.verbosity == 2 && manifest[T] == manifest[Any]) + printlog("warning: possible missing mtype call - toAtom with Def of type Any " + d) // AKS NOTE: this was removed on 6/27/12, but it is still a problem in OptiML apps without it, // so I'm putting it back until we can get it resolved properly. @@ -363,24 +325,23 @@ trait Effects extends Expressions with Blocks with Utils { // specifically, if we return the reified version of a mutable bound var, we get a Reflect(Reify(..)) error, e.g. mutable Sum // printlog("ignoring read of Reify(): " + d) super.toAtom(d) - case _ if conditionalScope && addControlDeps => reflectEffect(d, Control()) case _ => reflectEffect(d, Pure()) } // reflectEffect(d, Pure()) } - def reflectMirrored[A:Manifest](zd: Reflect[A])(implicit pos: SourceContext): Exp[A] = { + def reflectMirrored[A:Manifest](zd: Reflect[A]): Exp[A] = { checkContext() // warn if type is Any. TODO: make optional, sometimes Exp[Any] is fine if (manifest[A] == manifest[Any]) printlog("warning: possible missing mtype call - reflectMirrored with Def of type Any: " + zd) context.filter { case Def(d) if d == zd => true case _ => false }.reverse match { //case z::_ => z.asInstanceOf[Exp[A]] -- unsafe: we don't have a tight context, so we might pick one from a flattened subcontext - case _ => createReflectDefinition(fresh[A].withPos(List(pos)), zd) + case _ => createReflectDefinition(fresh[A], zd) } } def checkIllegalSharing(z: Exp[Any], mutableAliases: List[Sym[Any]]) { - if (mutableAliases.nonEmpty) { + if (mutableAliases.nonEmpty && Config.verbosity >= 1) { val zd = z match { case Def(zd) => zd } printerr("error: illegal sharing of mutable objects " + mutableAliases.mkString(", ")) printerr("at " + z + "=" + zd) @@ -429,21 +390,35 @@ trait Effects extends Expressions with Blocks with Utils { val mutableInputs = readMutableData(d) reflectEffectInternal(d, u andAlso Read(mutableInputs)) // will call super.toAtom if mutableInput.isEmpty } - + def reflectEffectInternal[A:Manifest](x: Def[A], u: Summary)(implicit pos: SourceContext): Exp[A] = { - if (mustPure(u)) super.toAtom(x) else { + /* + * We want to handle the case where a variable is first initialized to null, + * and later initialized to its actual value (also see variables.scala). + * This is the case with local state in the DBMS system of DATA lab. There, + * when the var is initialized to null, the context is null, since it is + * _outside_ any compiled method (we do not lift everything in this + * system!). Initial solution was to override the case where context == + * null, however this breaks test4-fac4 (reflectEffect when not needed). + */ + /*if (context == null) { + context = Nil + if (mustPure(u)) super.toAtom(x) + else { + val z = fresh[A] + val zd = Reflect(x,u,null) + createReflectDefinition(z, zd) + } + } else*/ if (mustPure(u)) super.toAtom(x) else { checkContext() // NOTE: reflecting mutable stuff *during mirroring* doesn't work right now. - // FIXME: Reflect(Reflect(ObjectUnsafeImmutable(..))) on delite - assert(!x.isInstanceOf[Reflect[_]], x) - val deps = calculateDependencies(u) val zd = Reflect(x,u,deps) if (mustIdempotent(u)) { context find { case Def(d) => d == zd } map { _.asInstanceOf[Exp[A]] } getOrElse { // findDefinition(zd) map (_.sym) filter (context contains _) getOrElse { // local cse TODO: turn around and look at context first?? - val z = fresh[A](List(pos)) + val z = fresh[A] if (!x.toString.startsWith("ReadVar")) { // supress output for ReadVar printlog("promoting to effect: " + z + "=" + zd) for (w <- u.mayRead) @@ -454,11 +429,13 @@ trait Effects extends Expressions with Blocks with Utils { } else { val z = fresh[A](List(pos)) // make sure all writes go to allocs - for (w <- u.mayWrite if !isWritableSym(w)) { - printerr("error: write to non-mutable " + w + " -> " + findDefinition(w)) - printerr("at " + z + "=" + zd) - printsrc("in " + quotePos(z)) - } + if (Config.verbosity >= 1) { + for (w <- u.mayWrite if !isWritableSym(w)) { + printerr("error: write to non-mutable " + w + " -> " + findDefinition(w)) + printerr("at " + z + "=" + zd) + printsrc("in " + quotePos(z)) + } + } // prevent sharing between mutable objects / disallow mutable escape for non read-only operations // make sure no mutable object becomes part of mutable result (in case of allocation) // or is written to another mutable object (in case of write) @@ -509,12 +486,11 @@ trait Effects extends Expressions with Blocks with Utils { val softWriteDeps = if (write.isEmpty) Nil else scope filter { case e@Def(Reflect(_, u, _)) => mayRead(u, write) } val writeDeps = if (write.isEmpty) Nil else scope filter { case e@Def(Reflect(_, u, _)) => mayWrite(u, write) || write.contains(e) } val simpleDeps = if (!u.maySimple) Nil else scope filter { case e@Def(Reflect(_, u, _)) => u.maySimple } - val controlDeps = if (!u.control) Nil else scope filter { case e@Def(Reflect(_, u, _)) => u.control } val globalDeps = scope filter { case e@Def(Reflect(_, u, _)) => u.mayGlobal } // TODO: write-on-read deps should be weak // TODO: optimize!! - val allDeps = canonic(readDeps ++ softWriteDeps ++ writeDeps ++ canonicLinear(simpleDeps) ++ canonicLinear(controlDeps) ++ canonicLinear(globalDeps)) + val allDeps = canonic(readDeps ++ softWriteDeps ++ writeDeps ++ canonicLinear(simpleDeps) ++ canonicLinear(globalDeps)) scope filter (allDeps contains _) } } @@ -561,18 +537,12 @@ trait Effects extends Expressions with Blocks with Utils { // reify the effects of an isolated block. // no assumptions about the current context remain valid. - def reifyEffects[A:Manifest](block: => Exp[A], controlScope: Boolean = false): Block[A] = { + def reifyEffects[A:Manifest](block: => Exp[A]): Block[A] = { val save = context context = Nil - // only add control dependencies scopes where controlScope is explicitly true (i.e., the first-level of an IfThenElse) - val saveControl = conditionalScope - conditionalScope = controlScope - val (result, defs) = reifySubGraph(block) - reflectSubGraph(defs) - - conditionalScope = saveControl + reflectSubGraph(defs) val deps = context val summary = summarizeAll(deps) @@ -583,19 +553,14 @@ trait Effects extends Expressions with Blocks with Utils { // reify the effects of a block that is executed 'here' (if it is executed at all). // all assumptions about the current context carry over unchanged. - def reifyEffectsHere[A:Manifest](block: => Exp[A], controlScope: Boolean = false): Block[A] = { + def reifyEffectsHere[A:Manifest](block: => Exp[A]): Block[A] = { val save = context if (save eq null) context = Nil - val saveControl = conditionalScope - conditionalScope = controlScope - val (result, defs) = reifySubGraph(block) reflectSubGraph(defs) - conditionalScope = saveControl - if ((save ne null) && context.take(save.length) != save) // TODO: use splitAt printerr("error: 'here' effects must leave outer information intact: " + save + " is not a prefix of " + context) diff --git a/src/internal/Expressions.scala b/src/internal/Expressions.scala index a2898c8c..a88ff522 100644 --- a/src/internal/Expressions.scala +++ b/src/internal/Expressions.scala @@ -16,22 +16,130 @@ import java.lang.{StackTraceElement,Thread} trait Expressions extends Utils { abstract class Exp[+T:Manifest] { // constants/symbols (atomic) - def tp: Manifest[T @uncheckedVariance] = manifest[T] //invariant position! but hey... + var tp: Manifest[T @uncheckedVariance] = manifest[T] //invariant position! but hey... def pos: List[SourceContext] = Nil + var emitted = false; } - case class Const[+T:Manifest](x: T) extends Exp[T] + case class Const[+T:Manifest](x: T) extends Exp[T] { + /** + * equals implementation in Const can not simply rely on default + * implementation for a case class, because we should check the + * type of Const for equality test. + * Otherwise, we might end-up generating code with wrong typing, + * specially upon CSE. + * + * For example, have a look at test1-arith/TestConstCSE: + * + * trait Prog extends ScalaOpsPkg { + * def test1(test_param: Rep[Boolean], acc: Rep[Long]): Rep[Long] = { + * val dblVal = if(test_param) unit(1.0) else unit(0.0) + * val lngVal = if(test_param) unit(1L) else unit(0L) + * auxMethod(acc + lngVal, dblVal) + * } + * + * def auxMethod(val1: Rep[Long], val2: Rep[Double]): Rep[Long] = { + * val1 + unit(133L) + rep_asinstanceof[Double, Long](val2,manifest[Double],manifest[Long]) + * } + * } + * + * That would generate a code containing a compile error: + * + * class test1 extends ((Boolean, Long)=>(Long)) { + * def apply(x0:Boolean, x1:Long): Long = { + * val x2 = if (x0) { + * 1.0 + * } else { + * 0.0 + * } + * val x3 = x1 + x2 + * val x4 = x3 + 133L + * val x5 = x2.asInstanceOf[Long] + * val x6 = x4 + x5 + * x6 + * } + * } + * + * :15: error: type mismatch; + * found : Double + * required: Long + * x6 + * ^ + * one error found + * compilation: had errors + * + * But, by introducing this new implementation for equals, the + * correct code will be generated: + * + * class test1 extends ((Boolean, Long)=>(Long)) { + * def apply(x0:Boolean, x1:Long): Long = { + * val x3 = if (x0) { + * 1L + * } else { + * 0L + * } + * val x4 = x1 + x3 + * val x5 = x4 + 133L + * val x2 = if (x0) { + * 1.0 + * } else { + * 0.0 + * } + * val x6 = x2.asInstanceOf[Long] + * val x7 = x5 + x6 + * x7 + * } + * } + * + * compilation: ok + */ + override def equals(that: Any) = that match { + case c@Const(y) => if(y == x) { + val thisTp = tp + //val thatTp = c.tp + if (Const.isNumeric[T](thisTp) /*&& isNumeric(thatTp)*/) + thisTp == c.tp //thatTp + else + true + } else false + case _ => false + } + } + + object Const { + val doubleManifest: Manifest[Double] = manifest[Double] + val floatManifest: Manifest[Float] = manifest[Float] + val longManifest: Manifest[Long] = manifest[Long] + val intManifest: Manifest[Int] = manifest[Int] + val shortManifest: Manifest[Short] = manifest[Short] + val byteManifest: Manifest[Byte] = manifest[Byte] + + def isNumeric[T:Manifest](m: Manifest[T]) = m == doubleManifest || + m == floatManifest || + m == longManifest || + m == intManifest || + m == shortManifest || + m == byteManifest + } case class Sym[+T:Manifest](val id: Int) extends Exp[T] { + val attributes: scala.collection.mutable.Map[Any,Any] = scala.collection.mutable.ListMap.empty + var sourceInfo = Thread.currentThread.getStackTrace // will go away var sourceContexts: List[SourceContext] = Nil override def pos = sourceContexts def withPos(pos: List[SourceContext]) = { sourceContexts :::= pos; this } } - case class Variable[+T](val e: Exp[Variable[T]]) // TODO: decide whether it should stay here ... FIXME: should be invariant + case class Variable[+T](val e: Exp[Variable[T]]) { + } // TODO: decide whether it should stay here ... FIXME: should be invariant var nVars = 0 - def fresh[T:Manifest]: Sym[T] = Sym[T] { nVars += 1; if (nVars%1000 == 0) printlog("nVars="+nVars); nVars -1 } + def fresh[T:Manifest]: Sym[T] = Sym[T] { + nVars += 1; + //if (nVars%1000 == 0) println("nVars="+nVars); + nVars - 1 + } + def fresh[T:Manifest](id: Int): Sym[T] = Sym[T] { id } def fresh[T:Manifest](pos: List[SourceContext]): Sym[T] = fresh[T].withPos(pos) @@ -45,6 +153,48 @@ trait Expressions extends Utils { cs.map(c => all(c).reverse.map(c => c.fileName.split("/").last + ":" + c.line).mkString("//")).mkString(";") } +/* + def fresh[T:Manifest] = { + val (name, id, nameId) = nextName("x") + val sym = Sym[T](id) + sym.name = name + sym.nameId = nameId + sym + } + + def fresh[T:Manifest](d: Def[T], ctx: Option[SourceContext]) = { + def enclosingNamedContext(sc: SourceContext): Option[SourceContext] = sc.bindings match { + case (null, _) :: _ => + if (!sc.parent.isEmpty) enclosingNamedContext(sc.parent.get) + else None + case (name, line) :: _ => + Some(sc) + } + + // create base name from source context + val (basename, line, srcCtx) = if (!ctx.isEmpty) { + enclosingNamedContext(ctx.get) match { + case None => + // no enclosing context has variable assignment + var outermost = ctx.get + while (!outermost.parent.isEmpty) { + outermost = outermost.parent.get + } + ("x", 0, Some(outermost)) + case Some(sc) => sc.bindings match { + case (n, l) :: _ => + (n, l, Some(sc)) + } + } + } else ("x", 0, None) + val (name, id, nameId) = nextName(basename) + val sym = Sym[T](id) + sym.name = name + sym.nameId = nameId + sym.sourceContext = srcCtx + sym + } +*/ abstract class Def[+T] { // operations (composite) override final lazy val hashCode = scala.runtime.ScalaRunTime._hashCode(this.asInstanceOf[Product]) @@ -94,8 +244,7 @@ trait Expressions extends Utils { def reflectSubGraph(ds: List[Stm]): Unit = { val lhs = ds.flatMap(_.lhs) assert(lhs.length == lhs.distinct.length, "multiple defs: " + ds) - // equivalent to: globalDefs filter (_.lhs exists (lhs contains _)) - val existing = lhs flatMap (globalDefsCache get _) + val existing = lhs flatMap (globalDefsCache get _)//globalDefs filter (_.lhs exists (lhs contains _)) assert(existing.isEmpty, "already defined: " + existing + " for " + ds) localDefs = localDefs ::: ds globalDefs = globalDefs ::: ds @@ -131,7 +280,7 @@ trait Expressions extends Utils { } object Def { - def unapply[T](e: Exp[T]): Option[Def[T]] = e match { + def unapply[T](e: Exp[T]): Option[Def[T]] = e match { // really need to test for sym? case s @ Sym(_) => findDefinition(s).flatMap(_.defines(s)) case _ => @@ -148,9 +297,8 @@ trait Expressions extends Utils { case ss: Iterable[Any] => ss.toList.flatMap(syms(_)) // All case classes extend Product! case p: Product => - // performance hotspot: this is the same as - // p.productIterator.toList.flatMap(syms(_)) - // but faster + //return p.productIterator.toList.flatMap(syms(_)) + /* performance hotspot */ val iter = p.productIterator val out = new ListBuffer[Sym[Any]] while (iter.hasNext) { @@ -192,7 +340,7 @@ trait Expressions extends Utils { case _ => Nil } - // generic symbol traversal: f is expected to call rsyms again + def rsyms[T](e: Any)(f: Any=>List[T]): List[T] = e match { case s: Sym[Any] => f(s) case ss: Iterable[Any] => ss.toList.flatMap(f) @@ -215,6 +363,25 @@ trait Expressions extends Utils { def freqCold(e: Any) = symsFreq(e).map(p=>(p._1,p._2*0.5)) + +/* + def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { + case s: Sym[Any] => List((s,1.0)) + case p: Product => p.productIterator.toList.flatMap(symsFreq(_)) + case _ => Nil + } +*/ + +/* + def symsShare(e: Any): List[(Sym[Any], Int)] = { + case s: Sym[Any] => List(s) + case p: Product => p.productIterator.toList.flatMap(symsShare(_)) + case _ => Nil + } +*/ + + + // bookkeeping def reset { // used by delite? diff --git a/src/internal/ExtendedExpressions.scala b/src/internal/ExtendedExpressions.scala new file mode 100644 index 00000000..ceefd75d --- /dev/null +++ b/src/internal/ExtendedExpressions.scala @@ -0,0 +1,53 @@ +package scala.lms +package internal + +import scala.reflect.SourceContext +import scala.annotation.unchecked.uncheckedVariance +import scala.collection.mutable.ListBuffer +import java.lang.{StackTraceElement,Thread} + + +trait ExtendedExpressions extends Expressions with Blocks { + val RefCountAttributeKey = "refCnt" + val ParentBlockAttributeKey = "pBlk" + + def infix_refCount(s: Sym[Any]): Int = { + s.attributes.get(RefCountAttributeKey).getOrElse(0).asInstanceOf[Int] + } + + def infix_incRefCount(s: Sym[Any], inc: Int): Unit = { + s.setRefCount(s.refCount + inc) + } + + def infix_setRefCount(s: Sym[Any], refCnt: Int): Unit = { + s.attributes.update(RefCountAttributeKey, refCnt) + } + + def infix_parentBlock(s: Sym[Any]): Option[Block[Any]] = { + s.attributes.get(ParentBlockAttributeKey).asInstanceOf[Option[Block[Any]]] + } + + def infix_setParentBlock(s: Sym[Any], pBlk: Option[Block[Any]]): Unit = pBlk match { + case Some(blk) => s.attributes.update(ParentBlockAttributeKey, blk) + case None => + } + + def infix_inSameParentBlockAs(thiz: Sym[Any], other: Sym[Any]): Boolean = { + val thizParent: Option[Block[Any]] = thiz.attributes.get(ParentBlockAttributeKey).asInstanceOf[Option[Block[Any]]] + val otherParent: Option[Block[Any]] = other.attributes.get(ParentBlockAttributeKey).asInstanceOf[Option[Block[Any]]] + thizParent match { + case Some(thizP) => otherParent match { + case Some(otherP) => thizP.res == otherP.res + case None => false + } + case None => otherParent match { + case Some(otherP) => false + case None => true + } + } + } + + def infix_possibleToInline(s: Sym[Any]): Boolean = s.refCount == 1 + + def infix_noReference(s: Sym[Any]): Boolean = s.refCount == 0 +} \ No newline at end of file diff --git a/src/internal/GenericCodegen.scala b/src/internal/GenericCodegen.scala index 3e9639fe..ada9810b 100644 --- a/src/internal/GenericCodegen.scala +++ b/src/internal/GenericCodegen.scala @@ -4,6 +4,7 @@ package internal import util.GraphUtil import java.io.{File, PrintWriter} import scala.reflect.RefinedManifest +import scala.collection.mutable.{Map => MMap} trait GenericCodegen extends BlockTraversal { val IR: Expressions @@ -11,25 +12,36 @@ trait GenericCodegen extends BlockTraversal { // TODO: should some of the methods be moved into more specific subclasses? - def deviceTarget: Targets.Value = throw new Exception("deviceTarget is not defined for this codegen.") - def hostTarget: Targets.Value = Targets.getHostTarget(deviceTarget) - def isAcceleratorTarget: Boolean = hostTarget != deviceTarget - def kernelFileExt = "" - def emitFileHeader(): Unit = {} def emitKernelHeader(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = {} def emitKernelFooter(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = {} + var analysisResults: MMap[String,Any] = null.asInstanceOf[MMap[String,Any]] + + /** + * List of transformers that should be applied before code generation + */ + var transformers: List[AbstractTransformer] = List[AbstractTransformer]() + + def performTransformations[A:Manifest](body: Block[A]): Block[A] = { + var transformedBody = body + transformers foreach { trans => + transformedBody = trans.apply[A](body.asInstanceOf[trans.IR.Block[A]]).asInstanceOf[this.Block[A]] + } + transformedBody + } + + def emitFileHeader(): Unit = {} + def emitFunctions(): Unit = {} + // Initializer - def initializeGenerator(buildDir:String, args: Array[String]): Unit = { } + def initializeGenerator(buildDir:String, args: Array[String], _analysisResults: MMap[String,Any]): Unit = { analysisResults = _analysisResults } def finalizeGenerator(): Unit = {} def kernelInit(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultIsVar: Boolean): Unit = {} - def emitDataStructures(stream: PrintWriter): Unit = {} + def emitDataStructures(out: PrintWriter): Unit = {} def emitDataStructures(path: String): Unit = {} - def getDataStructureHeaders(): String = "" - def emitTransferFunctions(): Unit = {} - + def dataPath = { "data" + java.io.File.separator } @@ -50,15 +62,24 @@ trait GenericCodegen extends BlockTraversal { stream.close() } - + + def runTransformations[A:Manifest](body: Block[A]): Block[A] = body + // exception handler def exceptionHandler(e: Exception, outFile:File, kstream:PrintWriter): Unit = { kstream.close() outFile.delete } - // optional type remapping (default is identity) - def remap(s: String): String = s + /** + * optional type remapping (default is identity) + * except that we should replace all '$' by '.' + * because inner class names might contain $ sign + */ + def remap(s: String): String = s match { + case "java.lang.Character" => "Character" + case _ => s.replace('$', '.') + } def remap[A](s: String, method: String, t: Manifest[A]) : String = remap(s, method, t.toString) def remap(s: String, method: String, t: String) : String = s + method + "[" + remap(t) + "]" def remap[A](m: Manifest[A]): String = m match { @@ -70,18 +91,18 @@ trait GenericCodegen extends BlockTraversal { val targs = m.typeArguments if (targs.length > 0) { val ms = m.toString - ms.take(ms.indexOf("[")+1) + targs.map(tp => remap(tp)).mkString(", ") + "]" + remap(ms.take(ms.indexOf("["))) + "[" + targs.map(tp => remap(tp)).mkString(", ") + "]" } - else m.toString + else remap(m.toString) } def remapImpl[A](m: Manifest[A]): String = remap(m) //def remapVar[A](m: Manifest[Variable[A]]) : String = remap(m.typeArguments.head) - - def remapHost[A](m: Manifest[A]): String = remap(m).replaceAll(deviceTarget.toString,hostTarget.toString) def hasMetaData: Boolean = false def getMetaData: String = null + def getDSLHeaders: String = null + // --------- var stream: PrintWriter = _ @@ -100,29 +121,36 @@ trait GenericCodegen extends BlockTraversal { } def emitBlock(y: Block[Any]): Unit = traverseBlock(y) + + def emitBlockResult[A: Manifest](b: Block[A]) { + if (remap(manifest[A]) != "Unit") stream.println(quote(getBlockResult(b))) + } def emitNode(sym: Sym[Any], rhs: Def[Any]): Unit = { throw new GenerationFailedException("don't know how to generate code for: " + rhs) } def emitValDef(sym: Sym[Any], rhs: String): Unit - def emitVarDecl(sym: Sym[Any]): Unit = throw new GenerationFailedException("don't know how to emit variable declaration " + quote(sym)) - def emitAssignment(sym: Sym[Any], rhs: String): Unit = throw new GenerationFailedException("don't know how to emit variable assignment " + quote(sym)) + + def emitSource0[R : Manifest](f: () => Exp[R], className: String, stream: PrintWriter, dynamicReturnType: String = null): List[(Sym[Any], Any)] = { + val body = reifyBlock(f()) + emitSource(List(), body, className, stream, dynamicReturnType) + } - def emitSource[T : Manifest, R : Manifest](f: Exp[T] => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { - val s = fresh[T] - val body = reifyBlock(f(s)) - emitSource(List(s), body, className, stream) + def emitSource1[T1: Manifest, R : Manifest](f: (Exp[T1]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val body = reifyBlock(f(s1)) + emitSource(List(s1), body, className, stream) } - def emitSource2[T1 : Manifest, T2 : Manifest, R : Manifest](f: (Exp[T1], Exp[T2]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + def emitSource2[T1: Manifest, T2: Manifest, R : Manifest](f: (Exp[T1], Exp[T2]) => Exp[R], className: String, stream: PrintWriter, dynamicReturnType: String = null): List[(Sym[Any], Any)] = { val s1 = fresh[T1] val s2 = fresh[T2] val body = reifyBlock(f(s1, s2)) - emitSource(List(s1, s2), body, className, stream) + emitSource(List(s1, s2), body, className, stream, dynamicReturnType) } - def emitSource3[T1 : Manifest, T2 : Manifest, T3 : Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + def emitSource3[T1: Manifest, T2: Manifest, T3: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { val s1 = fresh[T1] val s2 = fresh[T2] val s3 = fresh[T3] @@ -130,7 +158,7 @@ trait GenericCodegen extends BlockTraversal { emitSource(List(s1, s2, s3), body, className, stream) } - def emitSource4[T1 : Manifest, T2 : Manifest, T3 : Manifest, T4 : Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + def emitSource4[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { val s1 = fresh[T1] val s2 = fresh[T2] val s3 = fresh[T3] @@ -138,8 +166,8 @@ trait GenericCodegen extends BlockTraversal { val body = reifyBlock(f(s1, s2, s3, s4)) emitSource(List(s1, s2, s3, s4), body, className, stream) } - - def emitSource5[T1 : Manifest, T2 : Manifest, T3 : Manifest, T4 : Manifest, T5 : Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + + def emitSource5[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { val s1 = fresh[T1] val s2 = fresh[T2] val s3 = fresh[T3] @@ -148,6 +176,238 @@ trait GenericCodegen extends BlockTraversal { val body = reifyBlock(f(s1, s2, s3, s4, s5)) emitSource(List(s1, s2, s3, s4, s5), body, className, stream) } + + def emitSource6[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6)) + emitSource(List(s1, s2, s3, s4, s5, s6), body, className, stream) + } + def emitSource7[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7), body, className, stream) + } + def emitSource8[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8), body, className, stream) + } + def emitSource9[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9), body, className, stream) + } + def emitSource10[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10), body, className, stream) + } + def emitSource11[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11), body, className, stream) + } + def emitSource12[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12), body, className, stream) + } + def emitSource13[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13), body, className, stream) + } + def emitSource14[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, T14: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13], Exp[T14]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val s14 = fresh[T14] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14), body, className, stream) + } + def emitSource15[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, T14: Manifest, T15: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13], Exp[T14], Exp[T15]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val s14 = fresh[T14] + val s15 = fresh[T15] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15), body, className, stream) + } + def emitSource16[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, T14: Manifest, T15: Manifest, T16: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13], Exp[T14], Exp[T15], Exp[T16]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val s14 = fresh[T14] + val s15 = fresh[T15] + val s16 = fresh[T16] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16), body, className, stream) + } + def emitSource17[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, T14: Manifest, T15: Manifest, T16: Manifest, T17: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13], Exp[T14], Exp[T15], Exp[T16], Exp[T17]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val s14 = fresh[T14] + val s15 = fresh[T15] + val s16 = fresh[T16] + val s17 = fresh[T17] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16, s17)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16, s17), body, className, stream) + } + def emitSource18[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, T14: Manifest, T15: Manifest, T16: Manifest, T17: Manifest, T18: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13], Exp[T14], Exp[T15], Exp[T16], Exp[T17], Exp[T18]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val s14 = fresh[T14] + val s15 = fresh[T15] + val s16 = fresh[T16] + val s17 = fresh[T17] + val s18 = fresh[T18] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16, s17, s18)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16, s17, s18), body, className, stream) + } + def emitSource19[T1: Manifest, T2: Manifest, T3: Manifest, T4: Manifest, T5: Manifest, T6: Manifest, T7: Manifest, T8: Manifest, T9: Manifest, T10: Manifest, T11: Manifest, T12: Manifest, T13: Manifest, T14: Manifest, T15: Manifest, T16: Manifest, T17: Manifest, T18: Manifest, T19: Manifest, R : Manifest](f: (Exp[T1], Exp[T2], Exp[T3], Exp[T4], Exp[T5], Exp[T6], Exp[T7], Exp[T8], Exp[T9], Exp[T10], Exp[T11], Exp[T12], Exp[T13], Exp[T14], Exp[T15], Exp[T16], Exp[T17], Exp[T18], Exp[T19]) => Exp[R], className: String, stream: PrintWriter): List[(Sym[Any], Any)] = { + val s1 = fresh[T1] + val s2 = fresh[T2] + val s3 = fresh[T3] + val s4 = fresh[T4] + val s5 = fresh[T5] + val s6 = fresh[T6] + val s7 = fresh[T7] + val s8 = fresh[T8] + val s9 = fresh[T9] + val s10 = fresh[T10] + val s11 = fresh[T11] + val s12 = fresh[T12] + val s13 = fresh[T13] + val s14 = fresh[T14] + val s15 = fresh[T15] + val s16 = fresh[T16] + val s17 = fresh[T17] + val s18 = fresh[T18] + val s19 = fresh[T19] + val body = reifyBlock(f(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16, s17, s18, s19)) + emitSource(List(s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16, s17, s18, s19), body, className, stream) + } /** * @param args List of symbols bound to `body` @@ -155,16 +415,25 @@ trait GenericCodegen extends BlockTraversal { * @param className Name of the generated identifier * @param stream Output stream */ - def emitSource[A : Manifest](args: List[Sym[_]], body: Block[A], className: String, stream: PrintWriter): List[(Sym[Any], Any)] // return free static data in block + def emitSource[A : Manifest](args: List[Sym[_]], body: Block[A], className: String, stream: PrintWriter, dynamicReturnType: String = null, serializable: Boolean = false): List[(Sym[Any], Any)] // return free static data in block + + def quote(x: Exp[Any]) : String = quote(x, false) - def quote(x: Exp[Any]) : String = x match { - case Const(s: String) => "\""+s.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n")+"\"" // TODO: more escapes? - case Const(c: Char) => "'"+(""+c).replace("'", "\\'").replace("\n", "\\n")+"'" + def quote(x: Exp[Any], forcePrintSymbol: Boolean = false) : String = x match { + case Const(s: String) => "\""+s.replace("\"", "\\\"").replace("\n", "\\n")+"\"" // TODO: more escapes? + case Const(c: Char) => "'"+c+"'" case Const(f: Float) => "%1.10f".format(f) + "f" case Const(l: Long) => l.toString + "L" case Const(null) => "null" case Const(z) => z.toString - case Sym(n) => "x"+n + case s@Sym(n) => { + if (forcePrintSymbol) "x" + n + // Avoid printing symbols that are of type null + else if (s.tp.toString == "Unit") "" + else "x"+n + } + case x@_ if x == Const(null) => "null" + case null => "null" case _ => throw new RuntimeException("could not quote " + x) } @@ -175,25 +444,13 @@ trait GenericCodegen extends BlockTraversal { super.reset } - def isPrimitiveType[A](m: Manifest[A]) : Boolean = { - m.toString match { - case "Boolean" | "Byte" | "Char" | "Short" | "Int" | "Long" | "Float" | "Double" => true - case _ => false - } - } - def isVoidType[A](m: Manifest[A]) : Boolean = { m.toString match { case "Unit" => true case _ => false } } - - def isVariableType[A](m: Manifest[A]) : Boolean = { - if(m.erasure == classOf[Variable[AnyVal]]) true - else false - } - + // Provides automatic quoting and remapping in the gen string interpolater implicit class CodegenHelper(sc: StringContext) { def printToStream(arg: Any): Unit = { @@ -205,6 +462,7 @@ trait GenericCodegen extends BlockTraversal { case e: Exp[_] => quote(e) case m: Manifest[_] => remap(m) case s: String => s + case null => "null" case _ => throw new RuntimeException(s"Could not quote or remap $arg") } @@ -235,10 +493,41 @@ trait GenericCodegen extends BlockTraversal { -trait GenericNestedCodegen extends NestedBlockTraversal with GenericCodegen { - val IR: Expressions with Effects +trait GenericNestedCodegen extends NestedBlockTraversal with GenericCodegen { self => + val IR: Expressions with Effects with LoweringTransform import IR._ + /* Lowering stuff */ + def lowerNode[T:Manifest](sym: Sym[T], rhs: Def[T]): Unit = { +// println("Lowering " + sym + " with def " + rhs ) + rhs match { + case Reflect(s, u, effects) => lowerNode(sym, s) + case dflt@_ => { + //System.out.println("Don't know how to lower symbol " + dflt + ".") + () + } + } + } + + object HIRLowering extends LoweringTransformer + object LIRLowering extends LoweringTransformer + object LIRTraversal extends NestedBlockTraversal { + val IR: self.IR.type = self.IR + def apply[A: Manifest](b: Block[A]) = traverseBlock(b) + override def traverseStm(stm: Stm): Unit = stm match { + case TP(sym, rhs) => lowerNode(sym, rhs)(sym.tp) + case _ => throw new GenerationFailedException(s"don't know how to generate code for statement: $stm during LIRTraversal") + } + } + + def remapManifest[A:Manifest](m: Sym[A]): Manifest[_] = manifest[A] + + override def runTransformations[A:Manifest](body: Block[A]): Block[A] = { + val b = HIRLowering.run(body) + LIRTraversal(b) + LIRLowering.run(b) + } + override def traverseStm(stm: Stm) = super[GenericCodegen].traverseStm(stm) override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match { @@ -257,11 +546,10 @@ trait GenericNestedCodegen extends NestedBlockTraversal with GenericCodegen { // Allows the gen string interpolator to perform emitBlock when passed a Block implicit class NestedCodegenHelper(sc: StringContext) extends CodegenHelper(sc) { - override def printToStream(arg: Any): Unit = arg match { + override def printToStream(arg: Any): Unit = arg match { case NestedBlock(b) => emitBlock(b) case b: Block[_] => stream.print(quoteOrRemap(getBlockResult(b))) case _ => stream.print(quoteOrRemap(arg)) } } - } diff --git a/src/internal/GraphVizDependencyGraphExport.scala b/src/internal/GraphVizDependencyGraphExport.scala new file mode 100644 index 00000000..6cf6fc28 --- /dev/null +++ b/src/internal/GraphVizDependencyGraphExport.scala @@ -0,0 +1,214 @@ +package scala.lms +package internal + +import java.io.{File, FileWriter, PrintWriter} +import scala.lms.util.ReflectionUtil +import scala.reflect.SourceContext + +/** + * This code-generator, actually generates DependencyGraph + * between IR nodes graph in GraphViz format + * It works similar to other codeg-generators like + * ScalaCodegen or etc. + */ +trait GraphVizDependencyGraphExport extends GenericCodegen with NestedBlockTraversal { self => + val IR: ExtendedExpressions with Effects + import IR._ + + val GraphNodeKindInput = "input" + val GraphNodeKindOutput = "output" + + var levelCounter = 0; + + /** + * Produces IR graph node representation string, given: + * - a: node's symbol + * - kind: meta-data for the shape of node + * (can be GraphNodeKindInput, GraphNodeKindOutput + * , or any name for special nodes) + * - rhs: definition of this node (for finding dependencies) + */ + def getGraphNodeString[A](a: Sym[_], kind: String, rhs: Def[A]): String = { + var atp = remap(a.tp).toString; + + val extra = kind match { + case GraphNodeKindInput => "" + case _ => "shape=record" + } + + var indentation = indentString(levelCounter) + var output = { + indentation + "\"" + quote(a, true) + "\" [ " + extra + + " label = <" + getNodeLabel(a) + ":" + + atp + " = " + kind + ">];" + } + indentation = indentString(levelCounter+1) + if(rhs != null) { + ReflectionUtil.caseNameTypeValues(rhs) foreach { + x => if(x._2 == classOf[Block[A]]) { + val blk: Block[A] = x._3.asInstanceOf[Block[A]] + focusExactScope(blk) { levelScope => + val clusterName: String = findSymString(blk.res, levelScope) + val clusterStartingElem: String = findSymString(Const(0), levelScope) + if(clusterName != "") { + output += "\n"+indentation+"\"" + quote(a,true) + + "\" -> \"" + clusterStartingElem + + "\" [lhead=cluster_"+clusterName+" color=gray];" + } + } + emitBlock(blk) + } else if(x._2 == classOf[Exp[A]]) { + x._3 match { + case s:Sym[_] => { + output += "\n"+indentation+"\"" + quote(s,true) + + "\" -> \"" + quote(a,true) + "\";" + } + case _ => + } + } else if (x._2 == classOf[Variable[A]]) { + x._3.asInstanceOf[Variable[A]].e match { + case s:Sym[_] => { + output += "\n"+indentation+"\"" + quote(s,true) + + "\" -> \"" + quote(a,true) + "\";" + } + case _ => + } + } + } + } + + output + } + + override def emitSource[A : Manifest](args: List[Sym[_]], body: Block[A], className: String, out: PrintWriter, dynamicReturnType: String = null, serializable: Boolean = false) = { + + val sA = remap(manifest[A]) + + // TODO - reflect this static data with some (if any exists) representation in the graph + val staticData = getFreeDataBlock(body) + + withStream(out) { + levelCounter+=2 + val indentation = indentString(levelCounter-1) + val indentationPlus1 = indentString(levelCounter) + + stream.println(indentation + "subgraph cluster" + className + " {") + + emitFileHeader() + + var transformedBody = performTransformations(body) + + stream.println(args.map( a => getGraphNodeString(a, GraphNodeKindInput, null)).mkString("\n")) + emitBlock(transformedBody) + val res = quote(getBlockResult(transformedBody)) + if(res != "()" && res.isInstanceOf[Sym[_]]) { + stream.print("\n" + getGraphNodeString(res.asInstanceOf[Sym[_]], GraphNodeKindOutput, null)) + } + + stream.println("\n"+indentationPlus1+"label = \"" + className + "\";") + stream.println(indentationPlus1+"node [style=filled];") + stream.println(indentationPlus1+"color=blue;") + stream.println(indentation+"}") + levelCounter-=2 + } + + staticData + } + + override def traverseBlockFocused[A](block: Block[A]): Unit = { + focusExactScope(block) { levelScope => + val clusterName: String = findSymString(block.res, levelScope) + + val indentation = indentString(levelCounter) + val indentationPlus1 = indentString(levelCounter+1) + + if(clusterName != "") { + stream.println(indentation+"subgraph cluster_" + clusterName + " {") + } + traverseStmsInBlock(levelScope) + if(clusterName != "") { + stream.println("\n"+indentationPlus1+"label = \"" + clusterName + "\";") + stream.println(indentationPlus1+"node [style=filled];") + stream.println(indentationPlus1+"color=red;") + stream.println(indentation+"}") + } + } + } + + override def traverseStmsInBlock[A](stms: List[Stm]): Unit = { + levelCounter+=1 + stms foreach traverseStm + levelCounter-=1 + } + + override def traverseStm(stm: Stm) = stm match { + case TP(sym, rhs) => rhs match { + case Reflect(s, u, effects) => stream.println(getGraphNodeString(sym, s+"", s)) + case Reify(s, u, effects) => // just ignore -- effects are accounted for in emitBlock + case _ => stream.println(getGraphNodeString(sym, rhs+"", rhs)) + } + case _ => throw new GenerationFailedException("don't know how to generate code for statement: " + stm) + } + + /** + * This method produces the symbol name for a block, given: + * - an Exp for block result + * - list of statements in block + * returns: + * - in the case that block result is a symbol => symbol name + * - in the case that block result is a constant => + * first symbol name in the statements list + */ + def findSymString(e: Exp[_], stms: List[Stm]): String = { + def findFirstSymString: String = { + stms match { + case head :: tail => findSymString(head.asInstanceOf[TP[_]].sym, tail) + case Nil => "" + } + } + + e match { + case Const(x) => findFirstSymString + case s@Sym(n) => quote(s, true) + } + } + + /** + * It is possible to print more meta-data in the node-label + * by overriding this method + */ + def getNodeLabel(s: Sym[_]): String = quote(s, true); + + /** + * utility method for generating proper indentation prefix + * for given level. + */ + private def indentString(level: Int): String = { + def indentString(level: Int, acc: String): String = { + if(level <= 0) acc + else indentString(level - 1 , acc + " ") + } + indentString(level, "") + } + + // emitValDef is not used in this code generator + def emitValDef(sym: Sym[Any], rhs: String): Unit = {} + + override def performTransformations[A:Manifest](body: Block[A]): Block[A] = { + val transformedBody = super.performTransformations[A](body) + val fixer = new SymMetaDataFixerTransform{ val IR: self.IR.type = self.IR } + fixer.traverseBlock(transformedBody.asInstanceOf[fixer.Block[A]]) + transformedBody + } + + override def remap(s: String): String = { + val rs = super.remap(s) + val lastDot = rs.lastIndexOf('.') + val len = rs.length + if(lastDot > 0 && lastDot+1 < len) { + rs.substring(lastDot+1, len) + } else { + rs + } + } +} diff --git a/src/internal/ScalaCodegen.scala b/src/internal/ScalaCodegen.scala index 2d108fd0..d084556f 100644 --- a/src/internal/ScalaCodegen.scala +++ b/src/internal/ScalaCodegen.scala @@ -5,12 +5,10 @@ import java.io.{File, FileWriter, PrintWriter} import scala.reflect.SourceContext -trait ScalaCodegen extends GenericCodegen with Config { +trait ScalaCodegen extends GenericCodegen { val IR: Expressions import IR._ - override def deviceTarget: Targets.Value = Targets.Scala - override def kernelFileExt = "scala" override def toString = "scala" @@ -21,10 +19,9 @@ trait ScalaCodegen extends GenericCodegen with Config { outFile.delete } - def emitSource[A : Manifest](args: List[Sym[_]], body: Block[A], className: String, out: PrintWriter) = { - - val sA = remap(manifest[A]) + def emitSource[A : Manifest](args: List[Sym[_]], body: Block[A], className: String, out: PrintWriter, dynamicReturnType: String = null, serializable: Boolean = false) = { + val sA = if (dynamicReturnType != null) dynamicReturnType else remap(manifest[A]) val staticData = getFreeDataBlock(body) withStream(out) { @@ -33,13 +30,15 @@ trait ScalaCodegen extends GenericCodegen with Config { "*******************************************/") emitFileHeader() + val transformedBody = performTransformations(body) + // TODO: separate concerns, should not hard code "pxX" name scheme for static data here - stream.println("class "+className+(if (staticData.isEmpty) "" else "("+staticData.map(p=>"p"+quote(p._1)+":"+p._1.tp).mkString(",")+")")+" extends (("+args.map(a => remap(a.tp)).mkString(", ")+")=>("+sA+")) {") - stream.println("def apply("+args.map(a => quote(a) + ":" + remap(a.tp)).mkString(", ")+"): "+sA+" = {") - - emitBlock(body) - stream.println(quote(getBlockResult(body))) - + stream.print("class "+className+(if (staticData.isEmpty) "" else "("+staticData.map(p=>"p"+quote(p._1)+":"+p._1.tp).mkString(",")+")")+" extends (("+args.map( a => remap(a.tp)).mkString(", ")+")=>("+sA+"))") + if (serializable) stream.println("with Serializable {") else stream.println(" {") + emitFunctions() + stream.println("def apply("+args.map(a => quote(a, true) + ":" + remap(a.tp)).mkString(", ")+"): "+sA+" = {") + emitBlock(transformedBody) + if (sA != "Unit") stream.println(quote(getBlockResult(transformedBody))) stream.println("}") stream.println("}") @@ -51,15 +50,13 @@ trait ScalaCodegen extends GenericCodegen with Config { staticData } - override def emitFileHeader() { - // empty by default. override to emit package or import declarations. - } - override def emitKernelHeader(syms: List[Sym[Any]], vals: List[Sym[Any]], vars: List[Sym[Any]], resultType: String, resultIsVar: Boolean, external: Boolean): Unit = { val kernelName = syms.map(quote).mkString("") + + stream.println("package generated." + this.toString) stream.println("object kernel_" + kernelName + " {") stream.print("def apply(") - stream.print(vals.map(p => quote(p) + ":" + remap(p.tp)).mkString(",")) + stream.print(vals.map(p => quote(p, true) + ":" + remap(p.tp)).mkString(",")) // variable name mangling if (vals.length > 0 && vars.length > 0){ @@ -67,7 +64,7 @@ trait ScalaCodegen extends GenericCodegen with Config { } // TODO: remap Ref instead of explicitly adding generated.scala if (vars.length > 0){ - stream.print(vars.map(v => quote(v) + ":" + "generated.scala.Ref[" + remap(v.tp) +"]").mkString(",")) + stream.print(vars.map(v => quote(v, true) + ":" + "generated.scala.Ref[" + remap(v.tp) +"]").mkString(",")) } if (resultIsVar){ stream.print("): " + "generated.scala.Ref[" + resultType + "] = {") @@ -91,28 +88,26 @@ trait ScalaCodegen extends GenericCodegen with Config { } def emitValDef(sym: Sym[Any], rhs: String): Unit = { - val extra = if ((sourceinfo < 2) || sym.pos.isEmpty) "" else { + val extra = if ((Config.sourceinfo < 2) || sym.pos.isEmpty) "" else { val context = sym.pos(0) " // " + relativePath(context.fileName) + ":" + context.line } - stream.println("val " + quote(sym) + " = " + rhs + extra) + if (sym.tp != manifest[Unit]) + stream.println("val " + quote(sym) + " = " + rhs + extra) + else + stream.println(rhs + extra) } def emitVarDef(sym: Sym[Variable[Any]], rhs: String): Unit = { stream.println("var " + quote(sym) + ": " + remap(sym.tp) + " = " + rhs) +// stream.println("var " + quote(sym) + " = " + rhs) } - override def emitVarDecl(sym: Sym[Any]): Unit = { - stream.println("var " + quote(sym) + ": " + remap(sym.tp) + " = null.asInstanceOf[" + remap(sym.tp) + "];") - } - - override def emitAssignment(sym: Sym[Any], rhs: String): Unit = { - stream.println(quote(sym) + " = " + rhs) - } + def emitAssignment(sym: Sym[Any], lhs: String, rhs: String): Unit = emitValDef(sym, lhs + " = " + rhs) } trait ScalaNestedCodegen extends GenericNestedCodegen with ScalaCodegen { - val IR: Expressions with Effects + val IR: Expressions with Effects with LoweringTransform import IR._ // emit forward decls for recursive vals @@ -122,7 +117,7 @@ trait ScalaNestedCodegen extends GenericNestedCodegen with ScalaCodegen { } def emitForwardDef(sym: Sym[Any]): Unit = { - stream.println("var " + quote(sym) + /*": " + remap(sym.tp) +*/ " = null.asInstanceOf[" + remap(sym.tp) + "]") + stream.println("var " + quote(sym, true) + /*": " + remap(sym.tp) + */ " = null.asInstanceOf[" + remap(sym.tp) + "]") } // special case for recursive vals @@ -137,7 +132,7 @@ trait ScalaNestedCodegen extends GenericNestedCodegen with ScalaCodegen { trait ScalaFatCodegen extends GenericFatCodegen with ScalaCodegen { - val IR: Expressions with Effects with FatExpressions + val IR: Expressions with Effects with FatExpressions with LoweringTransform import IR._ def emitKernelExtra(syms: List[Sym[Any]]): Unit = { diff --git a/src/internal/ScalaCompile.scala b/src/internal/ScalaCompile.scala index 78f69424..eeea9667 100644 --- a/src/internal/ScalaCompile.scala +++ b/src/internal/ScalaCompile.scala @@ -1,29 +1,40 @@ package scala.lms package internal -import java.io._ +import java.io.{StringWriter, PrintWriter} +import scala.lms.util._ +import scala.sys.process._ import scala.tools.nsc._ +import scala.tools.nsc.settings._ import scala.tools.nsc.util._ import scala.tools.nsc.reporters._ import scala.tools.nsc.io._ - import scala.tools.nsc.interpreter.AbstractFileClassLoader +import java.lang.management.ManagementFactory; +import javax.management.ObjectName; +import javax.management.openmbean.CompositeData; +import javax.management.openmbean.CompositeDataSupport; -trait ScalaCompile extends Expressions { - - val codegen: ScalaCodegen { val IR: ScalaCompile.this.type } - +object ScalaCompile { + var compileCount = 0 + var dumpGeneratedCode = false var compiler: Global = _ var reporter: ConsoleReporter = _ - //var output: ByteArrayOutputStream = _ + // From what I understand, this is not currently exported from the JVM, but it used internally. + // (To check, run java -XX:+PrintFlagsFinal -version | grep Huge) and check for the limit. + val maximumHugeMethodLimit = 8000 + // NOTE: Always disable these two flags when running the test suite + val byteCodeSizeCheckEnabled: Boolean =false + var cleanerEnabled: Boolean = false + val source = new StringWriter() + var writer = new PrintWriter(source) + val workingDir = System.getProperty("user.dir") + "/CompiledClasses" + var loader: AbstractFileClassLoader = null + lazy val comp = this.compiler def setupCompiler() = { - /* - output = new ByteArrayOutputStream() - val writer = new PrintWriter(new OutputStreamWriter(output)) - */ val settings = new Settings() val pathSeparator = System.getProperty("path.separator") @@ -36,59 +47,162 @@ trait ScalaCompile extends Expressions { case _ => System.getProperty("sun.boot.class.path") } settings.encoding.value = "UTF-8" - settings.outdir.value = "." + //settings.processArguments(List("-optimise", "-feature", "-deprecation", "-language:postfixOps", "-Yinline-warnings"), true) + settings.processArguments(List("-feature", "-deprecation", "-language:postfixOps", "-Yinline-warnings"), true) + // XX TR: do not optimize (save compile time!) + + // Create output directory if it does not exist + val f = new java.io.File(ScalaCompile.workingDir) + if (!f.exists) + f.mkdirs() + + settings.outdir.value = ScalaCompile.workingDir settings.extdirs.value = "" //settings.verbose.value = true // -usejavacp needed on windows? + ScalaCompile.loader = new AbstractFileClassLoader(AbstractFile.getDirectory(workingDir), this.getClass.getClassLoader) + ScalaCompile.reporter = new ConsoleReporter(settings, null, new PrintWriter(System.out))//writer + ScalaCompile.compiler = new Global(settings, ScalaCompile.reporter) + } + def reset() { + setupCompiler() + compileCount = 0 + dumpGeneratedCode = false + } + +} + +trait ScalaCompile extends Expressions { - reporter = new ConsoleReporter(settings, null, new PrintWriter(System.out))//writer - compiler = new Global(settings, reporter) + val codegen: ScalaCodegen { val IR: ScalaCompile.this.type } + + def initCompile = { + // System.out.println("Initializing compiler...") // This unfortunately + // breaks the test suite as well :-( + ScalaCompile.source.getBuffer().setLength(0) + val className = "staged" + ScalaCompile.compileCount + ScalaCompile.compileCount = ScalaCompile.compileCount + 1 + className } - var compileCount = 0 - - var dumpGeneratedCode = false + def checkByteCodeSize(className: String): Int = { + lazy val runtime: Runtime = Runtime.getRuntime(); + // Is compiling huge methods allowed? + val mserver = ManagementFactory.getPlatformMBeanServer(); + val name = new ObjectName("com.sun.management:type=HotSpotDiagnostic"); + val operationName = "getVMOption"; + val params = Array[Object]("DontCompileHugeMethods") + val signature = Array[String](classOf[String].getName()) + val result = mserver.invoke(name,operationName,params,signature).asInstanceOf[CompositeDataSupport].get("value") + // If yes, then check the size + if (result == "true") { + val cmd = Seq("javap","-classpath",ScalaCompile.workingDir,"-c",className) #| Seq("cut","-d:","-f1") #| Seq("sort","-n") #| Seq("tail","-n1") + val size = cmd.!!.trim.toInt + if (size > ScalaCompile.maximumHugeMethodLimit) { + println("\n\n|------------------------------------------------------------------------------------|") + println("| CATASTROPHIC ERROR ENCOUNTERED!!! YOUR CODE IS TOO BIG (" + size + ") TO BE COMPILED BY THE JVM |") + println("| AND WILL BE INTERPRETED INSTEAD. THIS WILL CAUSE A DRAMATIC PERFORMANCE DROP. |") + println("| THE DEVELOPERS WORRY ABOUT YOUR MENTAL HEALTH, AND CANNOT ALLOW YOU TO EXPERIENCE |") + println("| THAT. EXITING NOW! |") + println("| |") + println("| Note: You have two alternatives: |") + println("| \t(a) Refactor your code so that the generated code size is smaller.(advised) |") + println("| \t(b) Set JVM Option DontCompileHugeMethods to false and rerun (Not advised). |") + println("| -----------------------------------------------------------------------------------|") + System.exit(0) + } + return size; + } + -1; + } - def compile[A,B](f: Exp[A] => Exp[B])(implicit mA: Manifest[A], mB: Manifest[B]): A=>B = { - if (this.compiler eq null) - setupCompiler() - - val className = "staged$" + compileCount - compileCount += 1 - - val source = new StringWriter() - val writer = new PrintWriter(source) - val staticData = codegen.emitSource(f, className, writer) + def compileLoadClass(src: StringWriter, className: String) = { + if (ScalaCompile.compiler eq null) + ScalaCompile.setupCompiler() + if (ScalaCompile.dumpGeneratedCode) println(src) + + ScalaCompile.compiler.settings.outputDirs.setSingleOutput(AbstractFile.getDirectory(ScalaCompile.workingDir)) + val run = new ScalaCompile.comp.Run + var parsedsrc = src.toString + + if (ScalaCompile.cleanerEnabled) { + println("\n\n------------------------------------------------") + println("EXPERIMENTAL:: CODE BEFORE RUNNING CODEGEN CLEANER.\n" + parsedsrc) + parsedsrc = CodegenCleaner.clean(src.toString) + println("\n\n------------------------------------------------") + println("EXPERIMENTAL:: CODE AFTER RUNNING CODEGEN CLEANER\n" + parsedsrc) + } - codegen.emitDataStructures(writer) + run.compileSources(List(new util.BatchSourceFile("", parsedsrc))) - if (dumpGeneratedCode) println(source) + ScalaCompile.reporter.printSummary() + if (ScalaCompile.reporter.hasErrors) { + println("compilation of the following code had errors:") + println(src) + System.exit(0) + } + ScalaCompile.reporter.reset - val compiler = this.compiler - val run = new compiler.Run + if (ScalaCompile.byteCodeSizeCheckEnabled) { + val size = checkByteCodeSize(className) + if (size != -1) println("ByteCode size of the compiled code is: " + size) + } - val fileSystem = new VirtualDirectory("", None) - compiler.settings.outputDirs.setSingleOutput(fileSystem) - // compiler.genJVM.outputDir = fileSystem + val cls: Class[_] = ScalaCompile.loader.loadClass(className) + cls + } - run.compileSources(List(new util.BatchSourceFile("", source.toString))) - reporter.printSummary() + def compile0[B](f: () => Exp[B], dynamicReturnType: String = null)(implicit mB: Manifest[B]): () =>B = { + val className = initCompile + val staticData = codegen.emitSource0(f, className, ScalaCompile.writer, dynamicReturnType) + codegen.emitDataStructures(ScalaCompile.writer) + val cls = compileLoadClass(ScalaCompile.source, className) + val cons = cls.getConstructor(staticData.map(_._1.tp.erasure):_*) + cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[()=>B] + } - if (!reporter.hasErrors) - println("compilation: ok") - else - println("compilation: had errors") + def compile1[A,B](f: Exp[A] => Exp[B])(implicit mA: Manifest[A], mB: Manifest[B]): A=>B = { + val className = initCompile + val staticData = codegen.emitSource1(f, className, ScalaCompile.writer) + codegen.emitDataStructures(ScalaCompile.writer) + val cls = compileLoadClass(ScalaCompile.source, className) + val cons = cls.getConstructor(staticData.map(_._1.tp.erasure):_*) + cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[A=>B] + } - reporter.reset - //output.reset + def compile2[A1,A2,B](f: (Exp[A1],Exp[A2]) => Exp[B])(implicit mA1: Manifest[A1], mA2: Manifest[A2], mB: Manifest[B]): (A1,A2)=>B = { + val className = initCompile + val staticData = codegen.emitSource2(f, className, ScalaCompile.writer) + codegen.emitDataStructures(ScalaCompile.writer) + val cls = compileLoadClass(ScalaCompile.source, className) + val cons = cls.getConstructor(staticData.map(_._1.tp.erasure):_*) + cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[(A1,A2)=>B] + } - val parent = this.getClass.getClassLoader - val loader = new AbstractFileClassLoader(fileSystem, this.getClass.getClassLoader) + def compile3[A1,A2,A3,B](f: (Exp[A1], Exp[A2], Exp[A3]) => Exp[B])(implicit mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mB: Manifest[B]): (A1,A2,A3)=>B = { + val className = initCompile + val staticData = codegen.emitSource3(f, className, ScalaCompile.writer) + codegen.emitDataStructures(ScalaCompile.writer) + val cls = compileLoadClass(ScalaCompile.source, className) + val cons = cls.getConstructor(staticData.map(_._1.tp.erasure):_*) + cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[(A1,A2,A3)=>B] + } - val cls: Class[_] = loader.loadClass(className) + def compile4[A1,A2,A3,A4,B](f: (Exp[A1], Exp[A2], Exp[A3], Exp[A4]) => Exp[B])(implicit mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mB: Manifest[B]): (A1,A2,A3,A4)=>B = { + val className = initCompile + val staticData = codegen.emitSource4(f, className, ScalaCompile.writer) + codegen.emitDataStructures(ScalaCompile.writer) + val cls = compileLoadClass(ScalaCompile.source, className) + val cons = cls.getConstructor(staticData.map(_._1.tp.erasure):_*) + cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[(A1,A2,A3,A4)=>B] + } + + def compile5[A1,A2,A3,A4,A5,B](f: (Exp[A1], Exp[A2], Exp[A3], Exp[A4], Exp[A5]) => Exp[B])(implicit mA1: Manifest[A1], mA2: Manifest[A2], mA3: Manifest[A3], mA4: Manifest[A4], mA5: Manifest[A5], mB: Manifest[B]): (A1,A2,A3,A4,A5)=>B = { + val className = initCompile + val staticData = codegen.emitSource5(f, className, ScalaCompile.writer) + codegen.emitDataStructures(ScalaCompile.writer) + val cls = compileLoadClass(ScalaCompile.source, className) val cons = cls.getConstructor(staticData.map(_._1.tp.erasure):_*) - - val obj: A=>B = cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[A=>B] - obj + cons.newInstance(staticData.map(_._2.asInstanceOf[AnyRef]):_*).asInstanceOf[(A1,A2,A3,A4,A5)=>B] } -} \ No newline at end of file +} diff --git a/src/internal/ScalaConciseCodegen.scala b/src/internal/ScalaConciseCodegen.scala new file mode 100644 index 00000000..26ada79b --- /dev/null +++ b/src/internal/ScalaConciseCodegen.scala @@ -0,0 +1,100 @@ +package scala.lms +package internal + +import java.io.{File, FileWriter, PrintWriter} +import scala.lms.util.ReflectionUtil +import scala.reflect.SourceContext + +/** + * ScalaConciseCodegen is just an extension to ScalaCodegen + * which inlines expressions that are possible to inline, + * instead of creating a new val-def for each of them, leading + * to a more compact and concise code. + * + * @author Mohammad Dashti (mohammad.dashti@epfl.ch) + */ +trait ScalaConciseCodegen extends ScalaNestedCodegen { self => + val IR: ExtendedExpressions with Effects with LoweringTransform + import IR._ + + override def emitValDef(sym: Sym[Any], rhs: String): Unit = { + val extra = if ((Config.sourceinfo < 2) || sym.pos.isEmpty) "" else { + val context = sym.pos(0) + " // " + relativePath(context.fileName) + ":" + context.line + } + sym match { + case s@Sym(n) => isVoidType(s.tp) match { + case true => stream.println("" + rhs + extra) + case false => if(s.possibleToInline || s.noReference) { + stream.print("("+rhs+")") + } else { + stream.println("val " + quote(sym) + " = " + rhs + extra) + } + } + case _ => stream.println("val " + quote(sym) + " = " + rhs + extra) + } + } + + override def emitAssignment(sym: Sym[Any], lhs: String, rhs: String): Unit = { + // if(isVoidType(sym.tp)) { + stream.println(lhs + " = " + rhs) + // } else { + // emitValDef(sym, lhs + " = " + rhs) + // } + } + + override def emitForwardDef(sym: Sym[Any]): Unit = { + if(!isVoidType(sym.tp)) { stream.println("var " + quote(sym, true) + /*": " + remap(sym.tp) +*/ " = null.asInstanceOf[" + remap(sym.tp) + "]") } + } + + override def traverseStm(stm: Stm) = stm match { + case TP(sym, rhs) => if(!sym.possibleToInline && sym.refCount > 0 /*for eliminating read-only effect-ful statements*/) emitNode(sym,rhs) + case _ => throw new GenerationFailedException("don't know how to generate code for statement: " + stm) + } + + override def quote(x: Exp[Any], forcePrintSymbol: Boolean) : String = { + def printSym(s: Sym[Any]): String = { + if(s.possibleToInline || s.noReference) { + Def.unapply(s) match { + case Some(d: Def[Any]) => { + val strWriter: java.io.StringWriter = new java.io.StringWriter; + val stream = new PrintWriter(strWriter); + withStream(stream) { + emitNode(s, d) + } + strWriter.toString + } + case None => "x"+s.id + } + } else { + "x"+s.id + } + } + x match { + case Const(s: String) => "\""+s.replace("\"", "\\\"").replace("\n", "\\n")+"\"" // TODO: more escapes? + case Const(c: Char) => "'"+c+"'" + case Const(f: Float) => "%1.10f".format(f) + "f" + case Const(l: Long) => l.toString + "L" + case Const(null) => "null" + case Const(z) => z.toString + case s@Sym(n) => if (forcePrintSymbol) { + printSym(s) + } else { + isVoidType(s.tp) match { + case true => "(" + /*"x" + n +*/ ")" + case false => printSym(s) + } + } + case null => "null" + case _ => throw new RuntimeException("could not quote " + x) + } + } + + override def performTransformations[A:Manifest](body: Block[A]): Block[A] = { + val transformedBody = super.performTransformations[A](body) + val fixer = new SymMetaDataFixerTransform{ val IR: self.IR.type = self.IR } + fixer.traverseBlock(transformedBody.asInstanceOf[fixer.Block[A]]) + transformedBody + } + +} diff --git a/src/internal/Scheduling.scala b/src/internal/Scheduling.scala index 5308e04a..f6dbea64 100644 --- a/src/internal/Scheduling.scala +++ b/src/internal/Scheduling.scala @@ -14,6 +14,8 @@ trait Scheduling { getSchedule(scope)(result, false) } + // PERFORMANCE: 'intersect' calls appear to be a hotspot + // checks if a and b share at least one element. O(N^2), but with no allocation and possible early exit. def containsAny(a: List[Sym[Any]], b: List[Sym[Any]]): Boolean = { var aIter = a @@ -46,41 +48,19 @@ trait Scheduling { xx.flatten.reverse } - //performance hotspot! - //should be O(1) wrt 'scope' (nodes in graph), try to keep this as efficient as possible - protected def scheduleDepsWithIndex(syms: List[Sym[Any]], cache: IdentityHashMap[Sym[Any], (Stm,Int)]): List[Stm] = { - //syms.map(cache.get(_)).filter(_ ne null).distinct.sortBy(_._2).map(_._1) - val sortedSet = new java.util.TreeSet[(Stm,Int)]( - new java.util.Comparator[(Stm,Int)] { def compare(a:(Stm,Int), b:(Stm,Int)) = if (b._2 < a._2) -1 else if (b._2 == a._2) 0 else 1 } - ) - - for (sym <- syms) { - val stm = cache.get(sym) - if (stm ne null) sortedSet.add(stm) - } - - var res: List[Stm] = Nil - val iter = sortedSet.iterator //return stms in the original order given by 'scope' - while (iter.hasNext) { - res ::= iter.next._1 - } - res - } + //FIXME: hotspot + def getSchedule(scope: List[Stm])(result: Any, sort: Boolean = true): List[Stm] = { + val scopeCache = new mutable.HashMap[Sym[Any],Stm] + for (stm <- scope; s <- stm.lhs) + scopeCache(s) = stm - protected def buildScopeIndex(scope: List[Stm]): IdentityHashMap[Sym[Any], (Stm,Int)] = { - val cache = new IdentityHashMap[Sym[Any], (Stm,Int)] - var idx = 0 - for (stm <- scope) { - for (s <- stm.lhs) cache.put(s, (stm,idx)) //remember the original order of the stms - idx += 1 + def deps(st: List[Sym[Any]]): List[Stm] = {//st flatMap (scopeCache.get(_).toList) + // scope.filter(d => (st intersect d.lhs).nonEmpty) + // scope.filter(d => containsAny(st, d.lhs)) + st sortBy(_.id) flatMap (scopeCache.get(_).toList) } - cache - } - def getSchedule(scope: List[Stm])(result: Any, sort: Boolean = true): List[Stm] = { - val scopeIndex = buildScopeIndex(scope) - - val xx = GraphUtil.stronglyConnectedComponents[Stm](scheduleDepsWithIndex(syms(result), scopeIndex), t => scheduleDepsWithIndex(syms(t.rhs), scopeIndex)) + val xx = GraphUtil.stronglyConnectedComponents[Stm](deps(syms(result)), t => deps(syms(t.rhs))) if (sort) xx.foreach { x => if (x.length > 1) { printerr("warning: recursive schedule for result " + result + ": " + x) @@ -90,6 +70,7 @@ trait Scheduling { xx.flatten.reverse } + //FIXME: hotspot def getScheduleM(scope: List[Stm])(result: Any, cold: Boolean, hot: Boolean): List[Stm] = { def mysyms(st: Any) = { val db = symsFreq(st).groupBy(_._1).mapValues(_.map(_._2).sum).toList @@ -100,9 +81,17 @@ trait Scheduling { else db.withFilter(p=>p._2 > 0.75 && p._2 < 100.0).map(_._1) } - val scopeIndex = buildScopeIndex(scope) + val scopeCache = new mutable.HashMap[Sym[Any],Stm] + for (stm <- scope; s <- stm.lhs) + scopeCache(s) = stm + + def deps(st: List[Sym[Any]]): List[Stm] = {//st flatMap (scopeCache.get(_).toList) + // scope.filter(d => (st intersect d.lhs).nonEmpty) + // scope.filter(d => containsAny(st, d.lhs)) + st flatMap (scopeCache.get(_).toList) + } - GraphUtil.stronglyConnectedComponents[Stm](scheduleDepsWithIndex(mysyms(result), scopeIndex), t => scheduleDepsWithIndex(mysyms(t.rhs), scopeIndex)).flatten.reverse + GraphUtil.stronglyConnectedComponents[Stm](deps(mysyms(result)), t => deps(mysyms(t.rhs))).flatten.reverse } @@ -171,8 +160,6 @@ trait Scheduling { traverse syms by ascending id. if sym s1 is used by s2, do not evaluate further uses of s2 because they are already there. - CAVEAT: TRANSFORMERS !!! - assumption: if s2 uses s1, the scope of s2 is completely included in s1's scope: val A = loop { s1 => ... val B = sum { s2 => ... val y = s2 + s1; .../* use y */ ... } } @@ -186,7 +173,7 @@ trait Scheduling { def getDepStuff(st: Sym[Any]) = { // could also precalculate uses, but computing all combinations eagerly is also expensive def uses(s: Sym[Any]): List[Stm] = if (seen(s)) Nil else { - //seen += s + seen += s lhsCache.getOrElse(s,Nil) ::: symsCache.getOrElse(s,Nil) filterNot (boundSymsCache.getOrElse(st, Nil) contains _) } GraphUtil.stronglyConnectedComponents[Stm]( @@ -196,21 +183,13 @@ trait Scheduling { } /* - reference impl:*/ - val res = sts.flatMap(getDepStuff).distinct + reference impl: + sts.flatMap(getDepStuff).distinct + */ - /*if (sts.contains(Sym(1064))) { - println("dep on x1064:") - res.foreach { r => - println(" " + r) - } - }*/ - res - - // CAVEAT: TRANSFORMERS !!! see CloseWorldRestage app in Delite - //sts.sortBy(_.id).flatMap(getDepStuff) + sts.sortBy(_.id).flatMap(getDepStuff) } /** end performance hotspot **/ -} +} \ No newline at end of file diff --git a/src/internal/SymMetaDataFixerTransform.scala b/src/internal/SymMetaDataFixerTransform.scala new file mode 100644 index 00000000..1ce58761 --- /dev/null +++ b/src/internal/SymMetaDataFixerTransform.scala @@ -0,0 +1,88 @@ +package scala.lms +package internal + +import scala.lms.util.ReflectionUtil + +/** + * There are some meta-data added to Sym using infix + * operations in ExtendedExpressions. + * This trait fixes this properties, e.g parentBlock + * and refCount. + * + * This information are gathered by a single pass + * traversal over Exp graph. + */ +trait SymMetaDataFixerTransform extends NestedBlockTraversal { + val IR: ExtendedExpressions with Effects + import IR._ + + override def traverseBlockFocused[A](block: Block[A]): Unit = { + focusExactScope(block) { levelScope => + levelScope foreach { stm => stm match { + case TP(sym, rhs) => sym.setParentBlock(Some(block)) + case _ => + } + } + + traverseStmsInBlock(levelScope) + } + } + + override def traverseStm(stm: Stm): Unit = { // override this to implement custom traversal + stm match { + case TP(sym, rhs) => { + rhs match { + case Reflect(s, u, effects) => { + if(!mustOnlyRead(u) && !mustOnlyAlloc(u)) { + sym.incRefCount(100) + } else if(mustOnlyAlloc(u)) { + sym.incRefCount(1) + } + increaseRefCountsOnRhs(sym, s) + } + case Reify(s, u, effects) => { + sym.incRefCount(-1000) // just ignore -- effects are accounted for in emitBlock + s match { + case s@Sym(n) => s.incRefCount(1) + case Const(x) => + } + } + case rhs => increaseRefCountsOnRhs(sym, rhs) + } + } + case _ => + } + blocks(stm.rhs) foreach traverseBlock + } + + private def increaseRefCountsOnRhs[A](s: Exp[Any], rhs: Def[A]): Unit = { + val sym = s.asInstanceOf[Sym[Any]] + ReflectionUtil.caseNameTypeValues(rhs) foreach { + x => if(x._2 == classOf[Block[A]]) { + val blk: Block[A] = x._3.asInstanceOf[Block[A]] + blk.res match { + case s:Sym[_] => if(s.inSameParentBlockAs(sym)) { s.incRefCount(1) } else { s.incRefCount(10) } + case _ => + } + //transformBlock[Any](blk) + } else if(x._2 == classOf[Exp[A]]) { + x._3 match { + case s:Sym[_] => if(s.inSameParentBlockAs(sym)) { s.incRefCount(1) } else { s.incRefCount(10) } + case _ => + } + } else if (x._2 == classOf[Variable[A]]) { + if (x._3 != null) { + x._3.asInstanceOf[Variable[A]].e match { + case s:Sym[_] => if(s.inSameParentBlockAs(sym)) { s.incRefCount(1) } else { s.incRefCount(10) } + case _ => + } + } + } else { + syms(x._3).foreach { + s: Sym[Any] => if(s.inSameParentBlockAs(sym)) { s.incRefCount(1) } else { s.incRefCount(10) } + } + } + } + } + +} diff --git a/src/internal/Transforming.scala b/src/internal/Transforming.scala index bf074ba0..329c5cb0 100644 --- a/src/internal/Transforming.scala +++ b/src/internal/Transforming.scala @@ -4,6 +4,7 @@ package internal import util.OverloadHack import scala.collection.{immutable,mutable} import scala.reflect.SourceContext +import scala.lms.common.WorklistTransformer trait AbstractTransformer { val IR: Expressions with Blocks with OverloadHack @@ -22,9 +23,6 @@ trait AbstractTransformer { def apply[A](xs: Seq[Exp[A]]): Seq[Exp[A]] = xs map (e => apply(e)) def apply[X,A](f: X=>Exp[A]): X=>Exp[A] = (z:X) => apply(f(z)) def apply[X,Y,A](f: (X,Y)=>Exp[A]): (X,Y)=>Exp[A] = (z1:X,z2:Y) => apply(f(z1,z2)) - def apply[X,Y,Z,A](f: (X,Y,Z)=>Exp[A]): (X,Y,Z)=>Exp[A] = (z1:X,z2:Y,z3:Z) => apply(f(z1,z2,z3)) - def apply[W,X,Y,Z,A](f: (W,X,Y,Z)=>Exp[A]): (W,X,Y,Z)=>Exp[A] = (z1:W,z2:X,z3:Y,z4:Z) => apply(f(z1,z2,z3,z4)) - def apply[V,W,X,Y,Z,A](f: (V,W,X,Y,Z)=>Exp[A]): (V,W,X,Y,Z)=>Exp[A] = (z1:V,z2:W,z3:X,z4:Y,z5:Z) => apply(f(z1,z2,z3,z4,z5)) //def apply[A](xs: Summary): Summary = xs //TODO def onlySyms[A](xs: List[Sym[A]]): List[Sym[A]] = xs map (e => apply(e)) collect { case e: Sym[A] => e } @@ -100,3 +98,27 @@ trait FatTransforming extends Transforming with FatExpressions { //def mirror[A:Manifest](e: FatDef, f: Transformer): Exp[A] = sys.error("don't know how to mirror " + e) } + +/* Lewis: adapted from LMS TestWorklistTransform2.scala */ +trait LoweringTransform extends FatTransforming with Effects { self => + trait LoweringTransformer extends WorklistTransformer { val IR: self.type = self } + + // ---------- Exp api + implicit def toAfter[A:Manifest](x: Def[A]) = new { def atPhase(t: LoweringTransformer)(y: => Exp[A]) = transformAtPhase(x)(t)(y) } + implicit def toAfter[A](x: Exp[A]) = new { def atPhase(t: LoweringTransformer)(y: => Exp[A]) = transformAtPhase(x)(t)(y) } + + // transform x to y at the *next* iteration of t. + // note: if t is currently active, it will continue the current pass with x = x. + // do we need a variant that replaces x -> y immediately if t is active? + def transformAtPhase[A](x: Exp[A])(t: LoweringTransformer)(y: => Exp[A]): Exp[A] = { + t.register(x)(y) + x + } + + def onCreate[A:Manifest](s: Sym[A], d: Def[A]): Exp[A] = s + + override def createDefinition[T](s: Sym[T], d: Def[T]): Stm = { + onCreate(s,d)(s.tp) + super.createDefinition(s,d) + } +} diff --git a/src/internal/Traversal.scala b/src/internal/Traversal.scala index 36e805fb..db39a7de 100644 --- a/src/internal/Traversal.scala +++ b/src/internal/Traversal.scala @@ -55,7 +55,7 @@ trait NestedGraphTraversal extends GraphTraversal with CodeMotion { rval = body } catch { - case e => throw e + case e: Throwable => throw e } finally { innerScope = saveInner @@ -79,9 +79,27 @@ trait NestedGraphTraversal extends GraphTraversal with CodeMotion { // strong order for levelScope (as obtained by code motion), taking care of recursive dependencies. def getStronglySortedSchedule2(scope: List[Stm], level: List[Stm], result: Any): (List[Stm], List[Sym[Any]]) = { - val scopeIndex = buildScopeIndex(scope) + import util.GraphUtil + import scala.collection.{mutable,immutable} + + val scopeCache = new mutable.HashMap[Sym[Any],Stm] + for (stm <- scope; s <- stm.lhs) + scopeCache(s) = stm + + //TR: wip! + + def deps(st: List[Sym[Any]]): List[Stm] = //st flatMap (scopeCache.get(_).toList) + { + val l1 = st sortBy(_.id) flatMap (scopeCache.get(_).toList) distinct; // need distinc?? + /*val l2 = scope.filter(d => (st intersect d.lhs).nonEmpty) sortBy(_.lhs.intersec(st).map(_.id).min) + if (l1 != l2) { + println("l1: " + l1) + println("l2: " + l2) + }*/ + l1 + } - val fixed = new collection.mutable.HashMap[Any,List[Sym[Any]]] + val fixed = new mutable.HashMap[Any,List[Sym[Any]]] def allSyms(r: Any) = fixed.getOrElse(r, syms(r) ++ softSyms(r)) @@ -89,7 +107,7 @@ trait NestedGraphTraversal extends GraphTraversal with CodeMotion { var recursive: List[Sym[Any]] = Nil - var xx = GraphUtil.stronglyConnectedComponents[Stm](scheduleDepsWithIndex(allSyms(result), scopeIndex), t => scheduleDepsWithIndex(allSyms(t.rhs), scopeIndex)) + var xx = GraphUtil.stronglyConnectedComponents[Stm](deps(allSyms(result)), t => deps(allSyms(t.rhs))) xx.foreach { xs => if (xs.length > 1 && (xs intersect level).nonEmpty) { printdbg("warning: recursive schedule for result " + result + ": " + xs) @@ -121,7 +139,7 @@ trait NestedGraphTraversal extends GraphTraversal with CodeMotion { } } } - xx = GraphUtil.stronglyConnectedComponents[Stm](scheduleDepsWithIndex(allSyms(result) ++ allSyms(recursive), scopeIndex), t => scheduleDepsWithIndex(allSyms(t.rhs), scopeIndex)) + xx = GraphUtil.stronglyConnectedComponents[Stm](deps(allSyms(result) ++ allSyms(recursive)), t => deps(allSyms(t.rhs))) xx.foreach { xs => if (xs.length > 1 && (xs intersect level).nonEmpty) { // see test5-schedfun. since we're only returning level scope (not inner) diff --git a/src/internal/Utils.scala b/src/internal/Utils.scala index d107bf38..7736176f 100644 --- a/src/internal/Utils.scala +++ b/src/internal/Utils.scala @@ -2,14 +2,11 @@ package scala.lms package internal // TODO: add logging, etc. -trait Utils extends Config { +trait Utils { def __ = throw new RuntimeException("unsupported embedded dsl operation") - def printdbg(x: =>Any) { if (verbosity >= 2) System.err.println(x) } - def printlog(x: =>Any) { if (verbosity >= 1) System.err.println(x) } - def printerr(x: =>Any) { System.err.println(x); hadErrors = true } - - def printsrc(x: =>Any) { if (sourceinfo >= 1) System.err.println(x) } - - var hadErrors = false -} \ No newline at end of file + def printdbg(x: =>Any) { if (Config.verbosity >= 2) System.err.println(x) } + def printlog(x: =>Any) { if (Config.verbosity >= 1) System.err.println(x) } + def printerr(x: =>Any) { System.err.println(x); } + def printsrc(x: =>Any) { if (Config.sourceinfo >= 1) System.err.println(x) } +} diff --git a/src/util/GencodeCleaner.scala b/src/util/GencodeCleaner.scala new file mode 100644 index 00000000..0205b4d8 --- /dev/null +++ b/src/util/GencodeCleaner.scala @@ -0,0 +1,117 @@ +package scala.lms.util + +import java.util.regex.Pattern +import scala.collection.mutable.ListBuffer + +object CodegenCleaner { + val pattern1 = Pattern.compile("val x[0-9]* = x[0-9]*$") + val pattern2 = Pattern.compile("var x[0-9]* = ") + val pattern3 = Pattern.compile("val x[0-9]* =") + val pattern4 = Pattern.compile("var x[0-9]* = x[0-9]*$") + val pattern5 = Pattern.compile("x[0-9]*$") + + def clean(src: String) = { + var lines = src.split("\n").map(x => x.trim) + // Extract variables + val variables = lines.filter(x => pattern2.matcher(x).find).map(x => { + val y = x.split("=") + ( y(0).replaceAll("var ","").trim, y(1).trim ) + }).sortBy(x => x._1) + // Extract values + val values = lines.filter(x => pattern3.matcher(x).find).map(x => { + val y = x.split("=") + ( y(0).replaceAll("val ","").trim, y(1).trim ) + }).sortBy(x => x._1) + // Extract "val x = x" lines + var resList = new ListBuffer[(String,String)]() + lines = lines.map(x => + if (pattern1.matcher(x).find) { + val y = x.split("=") + val valId = y(0).replaceAll("val ","").trim + val lhs = y(1).trim + variables.find(z => z._1 == lhs) match { + case Some(w) => { + resList += new Tuple2[String,String](valId,lhs) + "" + } + case None => x + } + } else x + ) +// println("PHASE 1A DONE") + // Now remove all references to this val + resList.foreach( res => { + lines = lines.map( line => line.replaceAll(res._1 + "\\.", res._2 + ".").replaceAll(res._1 + "$",res._2).replaceAll(res._1 + " ", res._2 + " ").replaceAll(res._1 + "\\(",res._2 + "(").replaceAll(res._1 + "\\+",res._2 + "+").replaceAll(res._1 + "\\)",res._2 + ")").replaceAll("= __" + res._1 + "Size", "= __" + res._2 + "Size").replaceAll("^__" + res._1 + "Size", "__" + res._2 + "Size").replaceAll("< __" + res._1 + "Size", "< __" + res._2 + "Size").replaceAll("__" + res._1 + "Indices", "__" + res._2 + "Indices").replaceAll("__" + res._1 + "LastIndex", "__" + res._2 + "LastIndex") ) + }) + + // Extract "var x = x" lines + resList = new ListBuffer[(String,String)]() + lines = lines.map(x => + if (pattern4.matcher(x).find) { + val y = x.split("=") + val valId = y(0).replaceAll("var ","").trim + val lhs = y(1).trim + variables.find(z => z._1 == lhs) match { + case Some(w) => { + resList += new Tuple2[String,String](valId,lhs) + "" + } + case None => x + } + } else x + ) +// println("PHASE 1A DONE") + // Now remove all references to this val + resList.foreach( res => { + lines = lines.map( line => line.replaceAll(res._1 + "\\.", res._2 + ".").replaceAll(res._1 + "$",res._2).replaceAll(res._1 + " ", res._2 + " ").replaceAll(res._1 + "\\(",res._2 + "(").replaceAll(res._1 + "\\+",res._2 + "+").replaceAll(res._1 + "\\)",res._2 + ")").replaceAll("= __" + res._1 + "Size", "= __" + res._2 + "Size").replaceAll("^__" + res._1 + "Size", "__" + res._2 + "Size").replaceAll("< __" + res._1 + "Size", "< __" + res._2 + "Size").replaceAll("__" + res._1 + "Indices", "__" + res._2 + "Indices").replaceAll("__" + res._1 + "LastIndex", "__" + res._2 + "LastIndex") ) + }) + + +// println("PHASE 1B DONE") + // CASE 2 + resList = new ListBuffer[(String,String)]() + lines = lines.map( line => { + if (pattern4.matcher(line).find) { + val y = line.split("=") + val varId = y(0).replaceAll("var ","").trim + val lhs = y(1).trim + values.find(z => z._1 == lhs) match { + case Some(w) => { + resList += new Tuple2[String,String](varId,lhs) + "" + } + case None => line + } + } else line + }) +// println("PHASE 2A DONE") + // Now change the val to var + resList.foreach( res => { + lines = lines.map( line => line.replaceAll("val " + res._2 + " ","var " + res._1 + " ")) + }) + +resList.foreach( res => { + lines = lines.map( line => line.replaceAll(res._1 + "\\.", res._2 + ".").replaceAll(res._1 + "$",res._2).replaceAll(res._1 + " ", res._2 + " ").replaceAll(res._1 + "\\(",res._2 + "(").replaceAll(res._1 + "\\+",res._2 + "+").replaceAll(res._1 + "\\)",res._2 + ")").replaceAll("= __" + res._1 + "Size", "= __" + res._2 + "Size").replaceAll("^__" + res._1 + "Size", "__" + res._2 + "Size").replaceAll("< __" + res._1 + "Size", "< __" + res._2 + "Size").replaceAll("__" + res._1 + "Indices", "__" + res._2 + "Indices").replaceAll("__" + res._1 + "LastIndex", "__" + res._2 + "LastIndex") ) + }) + + +// println("PHASE 2B DONE") + // CASE 3 + for (i <- 1 to lines.length - 1) { + if (lines(i).matches("x[0-9]*$") && pattern3.matcher(lines(i-1)).find) { + if (lines(i-1).startsWith("val " + lines(i))) { + val lhs = lines(i-1).split("=").drop(1).mkString("=") + lines(i-1) = lhs + lines(i) = "" + } + } + } +// println("PHASE 3 DONE") + + // print result + lines = lines.filter(x => x!= "" && x!="()") +// println("PHASE 4 DONE") + lines.mkString("\n") + } +} + diff --git a/src/util/GraphUtil.scala b/src/util/GraphUtil.scala index 1070bd66..96324b71 100644 --- a/src/util/GraphUtil.scala +++ b/src/util/GraphUtil.scala @@ -1,7 +1,11 @@ package scala.lms package util -import java.util.{ArrayDeque, HashMap} +import scala.collection.mutable.Map +import scala.collection.mutable.HashMap +import scala.collection.mutable.Stack +import scala.collection.mutable.Buffer +import scala.collection.mutable.ArrayBuffer object GraphUtil { @@ -29,27 +33,27 @@ object GraphUtil { def stronglyConnectedComponents[T](start: List[T], succ: T=>List[T]): List[List[T]] = { val id: Ref[Int] = new Ref(0) - val stack = new ArrayDeque[T] - val mark = new HashMap[T,Int] + val stack: Stack[T] = new Stack() + val mark: Map[T,Int] = new HashMap() - val res = new Ref[List[List[T]]](Nil) + val res: Buffer[Buffer[T]] = new ArrayBuffer() for (node <- start) visit(node,succ,id,stack,mark,res) - res.value + // TODO: get rid of reverse + + (for (scc <- res) yield scc.toList.reverse).toList.reverse } - def visit[T](node: T, succ: T=>List[T], id: Ref[Int], stack: ArrayDeque[T], - mark: HashMap[T,Int], res: Ref[List[List[T]]]): Int = { + def visit[T](node: T, succ: T=>List[T], id: Ref[Int], stack: Stack[T], + mark: Map[T,Int], res: Buffer[Buffer[T]]): Int = { + + mark.getOrElse(node, { - - if (mark.containsKey(node)) - mark.get(node) - else { id.value = id.value + 1 mark.put(node, id.value) - stack.addFirst(node) + stack.push(node) // println("push " + node) var min: Int = id.value @@ -60,20 +64,23 @@ object GraphUtil { min = m } - if (min == mark.get(node)) { - var scc: List[T] = Nil + if (min == mark(node)) { + + val scc: Buffer[T] = new ArrayBuffer() var loop: Boolean = true do { - val element = stack.removeFirst() + val element = stack.pop() // println("appending " + element) - scc ::= element + scc.append(element) mark.put(element, Integer.MAX_VALUE) loop = element != node } while (loop) - res.value ::= scc + res.append(scc) } min - } + + }) } + } diff --git a/src/util/OverloadHack.scala b/src/util/OverloadHack.scala index 8402e093..00cbe582 100644 --- a/src/util/OverloadHack.scala +++ b/src/util/OverloadHack.scala @@ -3,107 +3,31 @@ package util // hack to appease erasure -trait OverloadHack { - class Overloaded1 - class Overloaded2 - class Overloaded3 - class Overloaded4 - class Overloaded5 - class Overloaded6 - class Overloaded7 - class Overloaded8 - class Overloaded9 - class Overloaded10 - class Overloaded11 - class Overloaded12 - class Overloaded13 - class Overloaded14 - class Overloaded15 - class Overloaded16 - class Overloaded17 - class Overloaded18 - class Overloaded19 - class Overloaded20 - class Overloaded21 - class Overloaded22 - class Overloaded23 - class Overloaded24 - class Overloaded25 - class Overloaded26 - class Overloaded27 - class Overloaded28 - class Overloaded29 - class Overloaded30 - class Overloaded31 - class Overloaded32 - class Overloaded33 - class Overloaded34 - class Overloaded35 - class Overloaded36 - class Overloaded37 - class Overloaded38 - class Overloaded39 - class Overloaded40 - class Overloaded41 - class Overloaded42 - class Overloaded43 - class Overloaded44 - class Overloaded45 - class Overloaded46 - class Overloaded47 - class Overloaded48 - class Overloaded49 - class Overloaded50 - class Overloaded51 - class Overloaded52 - class Overloaded53 - class Overloaded54 - class Overloaded55 - class Overloaded56 - class Overloaded57 - class Overloaded58 - class Overloaded59 - class Overloaded60 - class Overloaded61 - class Overloaded62 - class Overloaded63 - class Overloaded64 - class Overloaded65 - class Overloaded66 - class Overloaded67 - class Overloaded68 - class Overloaded69 - class Overloaded70 - class Overloaded71 - class Overloaded72 - class Overloaded73 - class Overloaded74 - class Overloaded75 - class Overloaded76 - class Overloaded77 - class Overloaded78 - class Overloaded79 - class Overloaded80 - class Overloaded81 - class Overloaded82 - class Overloaded83 - class Overloaded84 - class Overloaded85 - class Overloaded86 - class Overloaded87 - class Overloaded88 - class Overloaded89 - class Overloaded90 - class Overloaded91 - class Overloaded92 - class Overloaded93 - class Overloaded94 - class Overloaded95 - class Overloaded96 - class Overloaded97 - class Overloaded98 - class Overloaded99 - class Overloaded100 +trait OverloadHack extends Serializable { + class Overloaded1 extends Serializable + class Overloaded2 extends Serializable + class Overloaded3 extends Serializable + class Overloaded4 extends Serializable + class Overloaded5 extends Serializable + class Overloaded6 extends Serializable + class Overloaded7 extends Serializable + class Overloaded8 extends Serializable + class Overloaded9 extends Serializable + class Overloaded10 extends Serializable + class Overloaded11 extends Serializable + class Overloaded12 extends Serializable + class Overloaded13 extends Serializable + class Overloaded14 extends Serializable + class Overloaded15 extends Serializable + class Overloaded16 extends Serializable + class Overloaded17 extends Serializable + class Overloaded18 extends Serializable + class Overloaded19 extends Serializable + class Overloaded20 extends Serializable + class Overloaded21 extends Serializable + class Overloaded22 extends Serializable + class Overloaded23 extends Serializable + class Overloaded24 extends Serializable implicit val overloaded1 = new Overloaded1 implicit val overloaded2 = new Overloaded2 @@ -129,81 +53,4 @@ trait OverloadHack { implicit val overloaded22 = new Overloaded22 implicit val overloaded23 = new Overloaded23 implicit val overloaded24 = new Overloaded24 - implicit val overloaded25 = new Overloaded25 - implicit val overloaded26 = new Overloaded26 - implicit val overloaded27 = new Overloaded27 - implicit val overloaded28 = new Overloaded28 - implicit val overloaded29 = new Overloaded29 - implicit val overloaded30 = new Overloaded30 - implicit val overloaded31 = new Overloaded31 - implicit val overloaded32 = new Overloaded32 - implicit val overloaded33 = new Overloaded33 - implicit val overloaded34 = new Overloaded34 - implicit val overloaded35 = new Overloaded35 - implicit val overloaded36 = new Overloaded36 - implicit val overloaded37 = new Overloaded37 - implicit val overloaded38 = new Overloaded38 - implicit val overloaded39 = new Overloaded39 - implicit val overloaded40 = new Overloaded40 - implicit val overloaded41 = new Overloaded41 - implicit val overloaded42 = new Overloaded42 - implicit val overloaded43 = new Overloaded43 - implicit val overloaded44 = new Overloaded44 - implicit val overloaded45 = new Overloaded45 - implicit val overloaded46 = new Overloaded46 - implicit val overloaded47 = new Overloaded47 - implicit val overloaded48 = new Overloaded48 - implicit val overloaded49 = new Overloaded49 - implicit val overloaded50 = new Overloaded50 - implicit val overloaded51 = new Overloaded51 - implicit val overloaded52 = new Overloaded52 - implicit val overloaded53 = new Overloaded53 - implicit val overloaded54 = new Overloaded54 - implicit val overloaded55 = new Overloaded55 - implicit val overloaded56 = new Overloaded56 - implicit val overloaded57 = new Overloaded57 - implicit val overloaded58 = new Overloaded58 - implicit val overloaded59 = new Overloaded59 - implicit val overloaded60 = new Overloaded60 - implicit val overloaded61 = new Overloaded61 - implicit val overloaded62 = new Overloaded62 - implicit val overloaded63 = new Overloaded63 - implicit val overloaded64 = new Overloaded64 - implicit val overloaded65 = new Overloaded65 - implicit val overloaded66 = new Overloaded66 - implicit val overloaded67 = new Overloaded67 - implicit val overloaded68 = new Overloaded68 - implicit val overloaded69 = new Overloaded69 - implicit val overloaded70 = new Overloaded70 - implicit val overloaded71 = new Overloaded71 - implicit val overloaded72 = new Overloaded72 - implicit val overloaded73 = new Overloaded73 - implicit val overloaded74 = new Overloaded74 - implicit val overloaded75 = new Overloaded75 - implicit val overloaded76 = new Overloaded76 - implicit val overloaded77 = new Overloaded77 - implicit val overloaded78 = new Overloaded78 - implicit val overloaded79 = new Overloaded79 - implicit val overloaded80 = new Overloaded80 - implicit val overloaded81 = new Overloaded81 - implicit val overloaded82 = new Overloaded82 - implicit val overloaded83 = new Overloaded83 - implicit val overloaded84 = new Overloaded84 - implicit val overloaded85 = new Overloaded85 - implicit val overloaded86 = new Overloaded86 - implicit val overloaded87 = new Overloaded87 - implicit val overloaded88 = new Overloaded88 - implicit val overloaded89 = new Overloaded89 - implicit val overloaded90 = new Overloaded90 - implicit val overloaded91 = new Overloaded91 - implicit val overloaded92 = new Overloaded92 - implicit val overloaded93 = new Overloaded93 - implicit val overloaded94 = new Overloaded94 - implicit val overloaded95 = new Overloaded95 - implicit val overloaded96 = new Overloaded96 - implicit val overloaded97 = new Overloaded97 - implicit val overloaded98 = new Overloaded98 - implicit val overloaded99 = new Overloaded99 - implicit val overloaded100 = new Overloaded100 - -} \ No newline at end of file +} diff --git a/src/util/ReflectionUtil.scala b/src/util/ReflectionUtil.scala new file mode 100644 index 00000000..f6cee792 --- /dev/null +++ b/src/util/ReflectionUtil.scala @@ -0,0 +1,30 @@ +package scala.lms +package util + +/** + * An object for reflection related utility methods + */ +object ReflectionUtil { + /** + * This method accepts an instance of a case class and returns + * the list of its fields. + * Each entry in the returned list is a tripple: + * - field name + * - field type + * - field value + */ + def caseNameTypeValues(a: AnyRef) = { + /** + * returns number of parameters for the first constructor of an object + */ + def numConstructorParams(a: AnyRef) = a.getClass.getConstructors()(0).getParameterTypes.size + /** + * returns list of fields in an instance of a case class + */ + def caseFields(a: AnyRef) = a.getClass.getDeclaredFields.toSeq.filterNot(_.isSynthetic).take(numConstructorParams(a)).map{field => + field.setAccessible(true) + field + } + caseFields(a).map{field => (field.getName, field.getType, field.get(a))} + } +} \ No newline at end of file diff --git a/src/util/Timing.scala b/src/util/Timing.scala new file mode 100644 index 00000000..ae37c57c --- /dev/null +++ b/src/util/Timing.scala @@ -0,0 +1,100 @@ +package scala.lms +package util + +import scala.reflect.{SourceContext, RefinedManifest} +import scala.lms.common._ +import scala.lms.internal._ + +trait Timing extends Base { + def timeGeneratedCode[A: Manifest](f: => Rep[A], msg: Rep[String] = unit("")): Rep[A] +} + +trait TimingExp extends BaseExp with EffectExp { + case class TimeGeneratedCode[A: Manifest](start: Exp[Long], end: Exp[Long], f: Block[A], msg: Rep[String] = unit("")) extends Def[A] { + val diff = fresh[Long] + } + + def timeGeneratedCode[A: Manifest](f: => Rep[A], msg: Rep[String] = unit("")) = { + val b = reifyEffects(f) + val start = fresh[Long] + val end = fresh[Long] + reflectEffect(TimeGeneratedCode[A](start, end, b, msg), summarizeEffects(b).star) + } + + override def syms(e: Any): List[Sym[Any]] = e match { + case TimeGeneratedCode(a, x, body, msg) => syms(body) + case _ => super.syms(e) + } + + override def boundSyms(e: Any): List[Sym[Any]] = e match { + case TimeGeneratedCode(a, x, body, msg) => effectSyms(body) + case _ => super.boundSyms(e) + } + + override def symsFreq(e: Any): List[(Sym[Any], Double)] = e match { + case TimeGeneratedCode(a, x, body, msg) => freqHot(body) + case _ => super.symsFreq(e) + } + + override def mirror[A:Manifest](e: Def[A], f: Transformer)(implicit pos: SourceContext): Exp[A] = (e match { + case Reflect(TimeGeneratedCode(s,e,body,msg),u,ef) => reflectMirrored(Reflect(TimeGeneratedCode(f(s),f(e),f(body),f(msg)), mapOver(f,u), f(ef))) + case _ => super.mirror(e,f) + }).asInstanceOf[Exp[A]] +} + +trait ScalaGenTiming extends ScalaGenBase with GenericNestedCodegen { + val IR: TimingExp + import IR._ + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { + rhs match { + case TimeGeneratedCode(start, end, f, msg) => { + stream.println("val " + quote(start) + " = System.nanoTime") + if (sym.tp != manifest[Unit]) + stream.print("val " + quote(sym) + " = { ") + emitBlock(f) + stream.println(quote(getBlockResult(f))) + if (sym.tp != manifest[Unit]) + stream.println("}") + stream.println("val " + quote(end) + " = System.nanoTime") + stream.print("System.out.println(\"Generated Code Profiling Info: Operation " + quote(msg).replaceAll("\"","") + " completed") + val calcStr = "((" + quote(end) + "-" + quote(start) + ")/(1000*1000))" + stream.println(" in \" + " + calcStr + " + \" milliseconds\")") + } + case _ => super.emitNode(sym, rhs) + } + } +} + + +trait CGenTiming extends CGenBase with GenericNestedCodegen { + val IR: TimingExp + import IR._ + + override def lowerNode[T:Manifest](sym: Sym[T], rhs: Def[T]) = rhs match { + case TimeGeneratedCode(start, end, f, msg) => { + LIRTraversal(f) + sym.atPhase(LIRLowering) { + reflectEffect(TimeGeneratedCode(start, end, LIRLowering(f), msg)).asInstanceOf[Exp[T]] + } + } + case _ => super.lowerNode(sym, rhs) + } + + override def emitNode(sym: Sym[Any], rhs: Def[Any]) = { + rhs match { + case t@TimeGeneratedCode(start, end, f, msg) => { + stream.println("struct timeval " + quote(start) + ", " + quote(end) + ", " + quote(t.diff) + ";") + stream.println("gettimeofday(&" + quote(start) + ", NULL);") + emitBlock(f) + stream.println("gettimeofday(&" + quote(end) + ", NULL);") + stream.println("timeval_subtract(&" + quote(t.diff) + ", &" + quote(end) + ", &" + quote(start) + ");") + + stream.print("fprintf(stderr,\"Generated Code Profiling Info: Operation completed in %ld milliseconds\\n\",") + stream.print("((" + quote(t.diff) + ".tv_sec * 1000) + (" +quote(t.diff) + ".tv_usec/1000))") + stream.println(");") + } + case _ => super.emitNode(sym, rhs) + } + } +} diff --git a/test-out/epfl/test1-constcse1.check b/test-out/epfl/test1-constcse1.check new file mode 100644 index 00000000..1a307830 --- /dev/null +++ b/test-out/epfl/test1-constcse1.check @@ -0,0 +1,26 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class test1 extends ((Boolean, Long)=>(Long)) { +def apply(x0:Boolean, x1:Long): Long = { +val x3 = if (x0) { +1L +} else { +0L +} +val x4 = x1 + x3 +val x5 = x4 + 133L +val x2 = if (x0) { +1.0 +} else { +0.0 +} +val x6 = x2.asInstanceOf[Long] +val x7 = x5 + x6 +x7 +} +} +/***************************************** + End of Generated Code +*******************************************/ +147 diff --git a/test-out/epfl/test15-generator-array.check b/test-out/epfl/test15-generator-array.check new file mode 100644 index 00000000..de7ace41 --- /dev/null +++ b/test-out/epfl/test15-generator-array.check @@ -0,0 +1,103 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class testMul extends ((Array[scala.Tuple2[Int, Int]], Int)=>(Unit)) { +def apply(x0:Array[scala.Tuple2[Int, Int]], x1:Int): Unit = { +val x2 = x0.length +val x3 = x2 + 1 +val x4 = x3 * x3 +val x5 = new Array[Recordintintint](x4) +var x7 : Int = 1 +while (x7 < x3) { +val x8 = x3 - x7 +var x10 : Int = 0 +while (x10 < x8) { +var x15 = 0 +var x16 = 0 +var x17 = 10000 +val x11 = x10 + x7 +val x12 = x10 + 1 +val x19 = x12 == x11 +if (x19) { +val x24 = x15 +val x25 = x16 +val x26 = x17 +val x28 = 0 < x26 +if (x28) { +val x20 = x0(x10) +val x21 = x20._1 +x15 = x21 +val x22 = x20._2 +x16 = x22 +x17 = 0 +() +} else { +() +} + +} else { +() +} +val x13 = x12 < x11 +val x38 = x10 * x3 +if (x13) { +var x37 : Int = x12 +while (x37 < x11) { +val x39 = x38 + x37 +val x40 = x5(x39) +val x41 = x37 * x3 +val x42 = x41 + x11 +val x43 = x5(x42) +val x44 = (x40,x43) +val x45 = x44._1 +val x46 = x44._2 +val x47 = x45.rows +val x48 = x46.cols +val x49 = x45.mults +val x50 = x46.mults +val x52 = x45.cols +val x57 = x15 +val x58 = x16 +val x59 = x17 +val x51 = x49 + x50 +val x53 = x47 * x52 +val x54 = x53 * x48 +val x55 = x51 + x54 +val x61 = x55 < x59 +if (x61) { +x15 = x47 +x16 = x48 +x17 = x55 +() +} else { +() +} + +x37 = x37 + 1 +} + +} else { +() +} +val x72 = x15 +val x73 = x16 +val x74 = x17 +val x71 = x38 + x11 +val x75 = Recordintintint(rows = x72, cols = x73, mults = x74) +x5(x71) = x75 + +x10 = x10 + 1 +} + +x7 = x7 + 1 +} +val x81 = 0 * x3 +val x82 = x81 + x2 +val x83 = x5(x82) +println(x83) +} +} +/***************************************** + End of Generated Code +*******************************************/ +case class Recordintintint(rows: Int, cols: Int, mults: Int) diff --git a/test-out/epfl/test15-generator-simple.check b/test-out/epfl/test15-generator-simple.check new file mode 100644 index 00000000..b755a7ee --- /dev/null +++ b/test-out/epfl/test15-generator-simple.check @@ -0,0 +1,374 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class test1 extends ((Int, Int)=>(Int)) { +def apply(x0:Int, x1:Int): Int = { +var x3 = 0 +val x2 = x0 < x1 +if (x2) { +var x5 : Int = x0 +while (x5 < x1) { +x3 = x5 + +x5 = x5 + 1 +} + +} else { +() +} +val x11 = x3 +x11 +} +} +/***************************************** + End of Generated Code +*******************************************/ +/***************************************** + Emitting Generated Code +*******************************************/ +class test2 extends ((Int, Int)=>(Int)) { +def apply(x13:Int, x14:Int): Int = { +var x16 = 0 +val x15 = x13 < x14 +if (x15) { +var x18 : Int = x13 +while (x18 < x14) { +val x19 = x18 * 2 +x16 = x19 + +x18 = x18 + 1 +} + +} else { +() +} +val x25 = x16 +x25 +} +} +/***************************************** + End of Generated Code +*******************************************/ +20 +/***************************************** + Emitting Generated Code +*******************************************/ +class test3 extends ((Int, Int)=>(Int)) { +def apply(x41:Int, x42:Int): Int = { +var x44 = 0 +val x43 = x41 < x42 +if (x43) { +var x46 : Int = x41 +while (x46 < x42) { +val x47 = x44 +val x48 = x47 + x46 +x44 = x48 + +x46 = x46 + 1 +} + +} else { +() +} +val x54 = x44 +x54 +} +} +/***************************************** + End of Generated Code +*******************************************/ +55 +/***************************************** + Emitting Generated Code +*******************************************/ +class test4 extends ((Int, Int)=>(Int)) { +def apply(x71:Int, x72:Int): Int = { +var x74 = 0 +val x73 = x71 < x72 +if (x73) { +var x76 : Int = x71 +while (x76 < x72) { +val x77 = x76 % 2 +val x78 = x77 != 0 +if (x78) { +val x79 = x74 +val x80 = x79 + x76 +x74 = x80 +() +} else { +() +} + +x76 = x76 + 1 +} + +} else { +() +} +val x88 = x74 +x88 +} +} +/***************************************** + End of Generated Code +*******************************************/ +25 +/***************************************** + Emitting Generated Code +*******************************************/ +class test5 extends ((Int, Int)=>(Int)) { +def apply(x109:Int, x110:Int): Int = { +var x112 = 0 +val x111 = x109 < x110 +if (x111) { +var x114 : Int = x109 +while (x114 < x110) { +val x115 = x112 +val x116 = x115 + x114 +x112 = x116 + +x114 = x114 + 1 +} + +} else { +() +} +if (x111) { +var x122 : Int = x109 +while (x122 < x110) { +val x123 = x122 % 2 +val x124 = x123 != 0 +if (x124) { +val x125 = x112 +val x126 = x125 + x122 +x112 = x126 +() +} else { +() +} + +x122 = x122 + 1 +} + +} else { +() +} +val x134 = x112 +x134 +} +} +/***************************************** + End of Generated Code +*******************************************/ +80 +/***************************************** + Emitting Generated Code +*******************************************/ +class test6 extends ((Int, Int)=>(Int)) { +def apply(x163:Int, x164:Int): Int = { +var x166 = 0 +val x165 = x163 < x164 +if (x165) { +var x168 : Int = x163 +while (x168 < x164) { +val x169 = x163 < x168 +if (x169) { +var x171 : Int = x163 +while (x171 < x168) { +val x172 = x166 +val x173 = x172 + x171 +x166 = x173 + +x171 = x171 + 1 +} + +} else { +() +} + +x168 = x168 + 1 +} + +} else { +() +} +val x183 = x166 +x183 +} +} +/***************************************** + End of Generated Code +*******************************************/ +20 +/***************************************** + Emitting Generated Code +*******************************************/ +class test8 extends ((Int)=>(Int)) { +def apply(x207:Int): Int = { +var x208 = 0 +val x209 = x208 +val x210 = x209 + 1 +x208 = x210 +val x212 = x208 +val x213 = x212 + 2 +x208 = x213 +val x215 = x208 +val x216 = x215 + 3 +x208 = x216 +val x218 = x208 +x218 +} +} +/***************************************** + End of Generated Code +*******************************************/ +6 +/***************************************** + Emitting Generated Code +*******************************************/ +class test6b extends ((Int, Int)=>(Int)) { +def apply(x233:Int, x234:Int): Int = { +var x236 = 0 +val x235 = x233 < x234 +val x252 = if (x235) { +var x238 : Int = x233 +while (x238 < x234) { +val x239 = x233 < x238 +val x248 = if (x239) { +var x241 : Int = x233 +while (x241 < x238) { +val x242 = x236 +val x243 = x242 + x241 +x236 = x243 + +x241 = x241 + 1 +} +true +} else { +false +} + +x238 = x238 + 1 +} +true +} else { +false +} +val x253 = x236 +x253 +} +} +/***************************************** + End of Generated Code +*******************************************/ +20 +/***************************************** + Emitting Generated Code +*******************************************/ +class test9 extends ((Int, Int)=>(Int)) { +def apply(x277:Int, x278:Int): Int = { +var x280 = 0 +val x279 = x277 < x278 +if (x279) { +var x282 : Int = x277 +while (x282 < x278) { +val x286 = x280 +val x283 = x282 * 2 +val x284 = x283 + 1 +val x285 = x284 * 3 +val x287 = x286 + x285 +x280 = x287 + +x282 = x282 + 1 +} + +} else { +() +} +val x293 = x280 +x293 +} +} +/***************************************** + End of Generated Code +*******************************************/ +297 +/***************************************** + Emitting Generated Code +*******************************************/ +class test10 extends ((Int)=>(Int)) { +def apply(x313:Int): Int = { +var x315 = 0 +val x314 = 1 < x313 +if (x314) { +var x317 : Int = 1 +while (x317 < x313) { +val x318 = 1 < x317 +if (x318) { +var x320 : Int = 1 +while (x320 < x317) { +val x322 = x315 +val x321 = x317 * x320 +val x323 = x322 + x321 +x315 = x323 + +x320 = x320 + 1 +} + +} else { +() +} + +x317 = x317 + 1 +} + +} else { +() +} +val x333 = x315 +x333 +} +} +/***************************************** + End of Generated Code +*******************************************/ +870 +/***************************************** + Emitting Generated Code +*******************************************/ +class test11 extends ((Int, Int)=>(Int)) { +def apply(x357:Int, x358:Int): Int = { +var x360 = 0 +val x359 = x357 < x358 +if (x359) { +var x362 : Int = x357 +while (x362 < x358) { +val x363 = x357 < x362 +if (x363) { +var x366 : Int = x357 +while (x366 < x362) { +val x367 = x360 +val x368 = x367 + x366 +x360 = x368 + +x366 = x366 + 1 +} + +} else { +() +} + +x362 = x362 + 1 +} + +} else { +() +} +val x378 = x360 +x378 +} +} +/***************************************** + End of Generated Code +*******************************************/ +20 diff --git a/test-out/epfl/test15-tupled-generator-flatten.check b/test-out/epfl/test15-tupled-generator-flatten.check new file mode 100644 index 00000000..8b9b65e0 --- /dev/null +++ b/test-out/epfl/test15-tupled-generator-flatten.check @@ -0,0 +1,27 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class test1 extends ((Double, Long, Long, Double, Double, org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Long, Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Long], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Long], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double])=>(scala.collection.immutable.List[scala.Tuple2[scala.Tuple2[Double, Long], Long]])) { +def apply(x0:Double, x1:Long, x2:Long, x3:Double, x4:Double, x5:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Long, Double], x6:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double], x7:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Long], x8:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Long], x9:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double]): scala.collection.immutable.List[scala.Tuple2[scala.Tuple2[Double, Long], Long]] = { +val x13 = x8 // mutable K3PersistentCollection +var x18 = (List()) +val x25 = -1L * x4 +(x13.slice(x2,(List(0)))).foreach{ +x19 => +val x23 = (x19._1)._2 +x18 = (((((x23,((if ((1000.0 < (x23 + x25))) { +1L +} else { +0L +}) + (if ((1000.0 < (x4 + (-1L * x23)))) { +1L +} else { +0L +})))),((x19._2) * 1L))) :: (x18)) +} +(x18) +} +} +/***************************************** + End of Generated Code +*******************************************/ diff --git a/test-out/epfl/test15-tupled-generator-huge.check b/test-out/epfl/test15-tupled-generator-huge.check new file mode 100644 index 00000000..2886eb3a --- /dev/null +++ b/test-out/epfl/test15-tupled-generator-huge.check @@ -0,0 +1,93 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class VWAPonInsertBIDS extends ((Double, Long, Long, Double, Double, org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double], org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double])=>(Unit)) { +def apply(x0:Double, x1:Long, x2:Long, x3:Double, x4:Double, x5:org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], x6:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double], x7:org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], x8:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double]): Unit = { +val x9 = x5 // mutable SimpleVal +val x10 = x6 // mutable K3PersistentCollection +val x11 = x7 // mutable SimpleVal +val x12 = x8 // mutable K3PersistentCollection +x10.updateValue(x4,((if ((x10.contains(x4))) { +(x10.lookup(x4,0)) +} else { +0.0 +}) + (x4 * x3))) +x11.update(((x11.get()) + x3)) +x12.updateValue(x4,((if ((x12.contains(x4))) { +(x12.lookup(x4,0)) +} else { +0.0 +}) + x3)) +var x33 = 0.0 +val x31 = (x11.get()) * 0.25 +x10.foreach{ +x35 => +var x42 = 0.0 +val x36 = x35._1 +x12.foreach{ +x43 => +x42 = ((x42) + ((x43._2) * (if ((x36 < (x43._1))) { +1.0 +} else { +0.0 +}))) +} +x33 = ((x33) + (((1L * (x35._2)) * 1L) * (if (((x42) < x31)) { +1.0 +} else { +0.0 +}))) +} +x9.update((0.0 + (x33))) +} +} +/***************************************** + End of Generated Code +*******************************************/ +/***************************************** + Emitting Generated Code +*******************************************/ +class VWAPonDeleteBIDS extends ((Double, Long, Long, Double, Double, org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double], org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double])=>(Unit)) { +def apply(x76:Double, x77:Long, x78:Long, x79:Double, x80:Double, x81:org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], x82:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double], x83:org.dbtoaster.dbtoasterlib.K3Collection.SimpleVal[Double], x84:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double]): Unit = { +val x85 = x81 // mutable SimpleVal +val x86 = x82 // mutable K3PersistentCollection +val x87 = x83 // mutable SimpleVal +val x88 = x84 // mutable K3PersistentCollection +x86.updateValue(x80,((if ((x86.contains(x80))) { +(x86.lookup(x80,0)) +} else { +0.0 +}) + ((-1L * x80) * x79))) +val x98 = -1L * x79 +x87.update(((x87.get()) + x98)) +x88.updateValue(x80,((if ((x88.contains(x80))) { +(x88.lookup(x80,0)) +} else { +0.0 +}) + x98)) +var x111 = 0.0 +val x109 = (x87.get()) * 0.25 +x86.foreach{ +x113 => +var x120 = 0.0 +val x114 = x113._1 +x88.foreach{ +x121 => +x120 = ((x120) + ((x121._2) * (if ((x114 < (x121._1))) { +1.0 +} else { +0.0 +}))) +} +x111 = ((x111) + (((1L * (x113._2)) * 1L) * (if (((x120) < x109)) { +1.0 +} else { +0.0 +}))) +} +x85.update((0.0 + (x111))) +} +} +/***************************************** + End of Generated Code +*******************************************/ diff --git a/test-out/epfl/test15-tupled-generator-simple.check b/test-out/epfl/test15-tupled-generator-simple.check new file mode 100644 index 00000000..f51e9d9d --- /dev/null +++ b/test-out/epfl/test15-tupled-generator-simple.check @@ -0,0 +1,36 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class test1 extends ((org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double], org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double])=>(Unit)) { +def apply(x0:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[Double, Double], x1:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double]): Unit = { +var x2 = 0.0 +x0.foreach{ +x3 => +x2 = ((x2) + ((x3._2) * (if ((12.2 < (x3._1))) { +1.0 +} else { +0.0 +}))) +} +} +} +/***************************************** + End of Generated Code +*******************************************/ +/***************************************** + Emitting Generated Code +*******************************************/ +class test2 extends ((Long, org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double])=>(Double)) { +def apply(x17:Long, x18:org.dbtoaster.dbtoasterlib.K3Collection.K3PersistentCollection[scala.Tuple2[Long, Double], Double]): Double = { +var x22 = 0.0 +val x30 = 200L * 75.32 +(x18.slice(x17,(List(0)))).foreach{ +x23 => +x22 = ((x22) + x30) +} +(x22) +} +} +/***************************************** + End of Generated Code +*******************************************/ diff --git a/test-out/epfl/test16-DBMSOpt1.check b/test-out/epfl/test16-DBMSOpt1.check new file mode 100644 index 00000000..e0554ad5 --- /dev/null +++ b/test-out/epfl/test16-DBMSOpt1.check @@ -0,0 +1,2 @@ +-1 +0 diff --git a/test-out/epfl/test16-DBMSOpt2.check b/test-out/epfl/test16-DBMSOpt2.check new file mode 100644 index 00000000..f7f62d72 --- /dev/null +++ b/test-out/epfl/test16-DBMSOpt2.check @@ -0,0 +1,69 @@ +/***************************************** + Emitting Generated Code +*******************************************/ +class lala extends (()=>(Array[Array[Int]])) { +def apply(): Array[Array[Int]] = { +val x0 = new Array[Array[Int]](3) +var x1 = x0 +val x2 = x1 +val x3 = new Array[Int](3) +x2(0) = x3 +val x5 = new Array[Int](3) +x2(1) = x5 +val x7 = new Array[Int](3) +x2(2) = x7 +val x9 = x2(0) +var x10 = x9 +val x11 = x10 +x11(0) = 2 +val x13 = x11.mkString(",") +println(x13) +val x15 = new Array[Int](7) +x10 = x15 +val x17 = x10 +x17(6) = 1 +val x19 = x17.mkString(",") +println(x19) +x2(0) = x17 +val x22 = x2(1) +var x23 = x22 +val x24 = x23 +x24(1) = 2 +val x26 = x24.mkString(",") +println(x26) +val x28 = new Array[Int](7) +x23 = x28 +val x30 = x23 +x30(5) = 1 +val x32 = x30.mkString(",") +println(x32) +x2(1) = x30 +val x35 = x2(2) +var x36 = x35 +val x37 = x36 +x37(2) = 2 +val x39 = x37.mkString(",") +println(x39) +val x41 = new Array[Int](7) +x36 = x41 +val x43 = x36 +x43(4) = 1 +val x45 = x43.mkString(",") +println(x45) +x2(2) = x43 +x2 +} +} +/***************************************** + End of Generated Code +*******************************************/ +2,0,0 +0,0,0,0,0,0,1 +0,2,0 +0,0,0,0,0,1,0 +0,0,2 +0,0,0,0,1,0,0 +3 +0,0,0,0,0,0,1 +0,0,0,0,0,1,0 +0,0,0,0,1,0,0 diff --git a/test-src/epfl/test1-arith/TestConstCSE.scala b/test-src/epfl/test1-arith/TestConstCSE.scala new file mode 100644 index 00000000..d0ab1edd --- /dev/null +++ b/test-src/epfl/test1-arith/TestConstCSE.scala @@ -0,0 +1,54 @@ +package scala.lms +package epfl +package test1 + +import common._ +import internal._ +import java.io._ + +import scala.reflect.SourceContext + + +class TestConstCSE extends FileDiffSuite { + + val prefix = "test-out/epfl/test1-" + + /** + * This test targets checking resolved bug for equality check + * on Const values. For more information, have a look at + * "equals" method implementation (and its comments) for Const + * class inside Expressions trait. + */ + def testBugConstCSE1 = { + withOutFile(prefix+"constcse1") { + trait Prog extends ScalaOpsPkg { + def test1(test_param: Rep[Boolean], acc: Rep[Long]): Rep[Long] = { + val dblVal = if(test_param) unit(1.0) else unit(0.0) + val lngVal = if(test_param) unit(1L) else unit(0L) + auxMethod(acc + lngVal, dblVal) + } + + def auxMethod(val1: Rep[Long], val2: Rep[Double]): Rep[Long] = { + val1 + unit(133L) + rep_asinstanceof[Double, Long](val2,manifest[Double],manifest[Long]) + } + } + + new Prog with ScalaOpsPkgExp with ScalaCompile{ self => + + val printWriter = new java.io.PrintWriter(System.out) + + //test1: first "loop" + val codegen = new ScalaCodeGenPkg with ScalaCodegen{ val IR: self.type = self } + + codegen.emitSource2(test1 _ , "test1", printWriter) + val source = new StringWriter + val testc1 = compile2(test1) + scala.Console.println(testc1(true,12)) + + + } + } + assertFileEqualsCheck(prefix+"constcse1") + } + +}