Skip to content

Commit df3edc5

Browse files
committed
Merge pull request #97 from afernandez90/develop
Lambda function and if-then-else return value deduction
2 parents 5c30222 + 5da2bab commit df3edc5

File tree

4 files changed

+33
-13
lines changed

4 files changed

+33
-13
lines changed

src/common/Functions.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,8 @@ trait CGenFunctions extends CGenEffect with BaseGenFunctions {
350350
override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
351351
case e@Lambda(fun, x, y) =>
352352
val retType = remap(getBlockResult(y).tp)
353-
stream.println("function<"+retType+"("+
354-
remap(x.tp)+")> "+quote(sym)+
353+
val retTp = if (cppExplicitFunRet == "true") "function<"+retType+"("+remap(x.tp)+")>" else "auto"
354+
stream.println(retTp+" "+quote(sym)+
355355
" = [&]("+remap(x.tp)+" "+quote(x)+") {")
356356
emitBlock(y)
357357
val z = getBlockResult(y)
@@ -378,8 +378,8 @@ trait CGenTupledFunctions extends CGenFunctions with GenericGenUnboxedTupleAcces
378378
override def emitNode(sym: Sym[Any], rhs: Def[Any]) = rhs match {
379379
case Lambda(fun, UnboxedTuple(xs), y) =>
380380
val retType = remap(getBlockResult(y).tp)
381-
stream.println("function<"+retType+"("+
382-
xs.map(s=>remap(s.tp)).mkString(",")+")> "+quote(sym)+
381+
val retTp = if (cppExplicitFunRet == "true") "function<"+retType+"("+xs.map(s=>remap(s.tp)).mkString(",")+")>" else "auto"
382+
stream.println(retTp+" "+quote(sym)+
383383
" = [&]("+xs.map(s=>remap(s.tp)+" "+quote(s)).mkString(",")+") {")
384384
emitBlock(y)
385385
val z = getBlockResult(y)

src/common/IfThenElse.scala

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,28 @@ trait CGenIfThenElse extends CGenEffect with BaseGenIfThenElse {
384384
emitBlock(b)
385385
stream.println("}")
386386
case _ =>
387-
stream.println("%s %s;".format(remap(sym.tp),quote(sym)))
388-
stream.println("if (" + quote(c) + ") {")
389-
emitBlock(a)
390-
stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(a))))
391-
stream.println("} else {")
392-
emitBlock(b)
393-
stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(b))))
394-
stream.println("}")
387+
if (cppIfElseAutoRet == "true") {
388+
val ten = quote(sym) + "True"
389+
val fen = quote(sym) + "False"
390+
def emitCondFun[T: Manifest](fname: String, block: Block[T]) {
391+
stream.println("auto " + fname + " = [&]() {");
392+
emitBlock(block)
393+
stream.println("return " + quote(getBlockResult(block)) + ";")
394+
stream.println("};")
395+
}
396+
emitCondFun(ten, a)
397+
emitCondFun(fen, b)
398+
stream.println("auto " + quote(sym) + " = " + quote(c) + " ? " + ten + "() : " + fen + "();")
399+
} else {
400+
stream.println("%s %s;".format(remap(sym.tp),quote(sym)))
401+
stream.println("if (" + quote(c) + ") {")
402+
emitBlock(a)
403+
stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(a))))
404+
stream.println("} else {")
405+
emitBlock(b)
406+
stream.println("%s = %s;".format(quote(sym),quote(getBlockResult(b))))
407+
stream.println("}")
408+
}
395409
}
396410
/*
397411
val booll = remap(sym.tp).equals("void")

src/common/MathOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ trait MathOps extends Base {
2828
def abs[A:Manifest:Numeric](x: Rep[A])(implicit pos: SourceContext) = math_abs(x)
2929
def max[A:Manifest:Numeric](x: Rep[A], y: Rep[A])(implicit pos: SourceContext) = math_max(x,y)
3030
def min[A:Manifest:Numeric](x: Rep[A], y: Rep[A])(implicit pos: SourceContext) = math_min(x,y)
31-
def Pi(implicit pos: SourceContext) = math_pi
31+
def Pi(implicit pos: SourceContext) = 3.141592653589793238462643383279502884197169
3232
def E(implicit pos: SourceContext) = math_e
3333
}
3434

src/internal/Config.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,10 @@ trait Config {
88

99
// memory management type for C++ target (refcnt or gc)
1010
val cppMemMgr = System.getProperty("lms.cpp.memmgr","malloc")
11+
12+
// explicit return type of lambda functions (allows recursive functions but is less generic)
13+
val cppExplicitFunRet = System.getProperty("lms.cpp.explicitFunRet","true")
14+
15+
// auto return value of if-else expressions (allows type deduction on if-then-else expressions)
16+
val cppIfElseAutoRet = System.getProperty("lms.cpp.ifElseAutoRet","false")
1117
}

0 commit comments

Comments
 (0)