Skip to content

Commit 93fba4e

Browse files
committed
Add tanh method
1 parent e23dca3 commit 93fba4e

File tree

4 files changed

+33
-0
lines changed

4 files changed

+33
-0
lines changed

Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ object Expressions {
120120
def exp(operand: FloatTerm): FloatTerm
121121
def abs(operand: FloatTerm): FloatTerm
122122
def sqrt(operand: FloatTerm): FloatTerm
123+
def tanh(operand: FloatTerm): FloatTerm
123124
}
124125

125126
/** @template */

OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ trait OpenCLKernelBuilder extends AllExpressions {
265265
"""
266266
float.termFactory.newInstance(valueTermName)
267267
}
268+
269+
def tanh(operand0: FloatTerm): FloatTerm = {
270+
val valueTermName = freshName("")
271+
localDefinitions += fastraw"""
272+
const ${operand0.typeCode} $valueTermName = tanh(${operand0.termCode});
273+
"""
274+
float.termFactory.newInstance(valueTermName)
275+
}
268276
}
269277

270278
type FloatType <: (ValueType with Any) with ClFloatType

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,10 @@ trait Tensors extends OpenCL {
502502
leftHandSide.derivedTensor(trees.float.sqrt(leftHandSide.closure.asInstanceOf[FloatTerm]))
503503
}
504504

505+
def tanh(leftHandSide: Tensor): Tensor = {
506+
leftHandSide.derivedTensor(trees.float.tanh(leftHandSide.closure.asInstanceOf[FloatTerm]))
507+
}
508+
505509
def exp(leftHandSide: Tensor): Tensor = {
506510
leftHandSide.derivedTensor(trees.float.exp(leftHandSide.closure.asInstanceOf[FloatTerm]))
507511
}

Trees/src/main/scala/com/thoughtworks/compute/Trees.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,21 @@ object Trees {
421421
}
422422
}
423423

424+
/** @group AST */
425+
@(silent @companionObject)
426+
final case class Tanh(operand0: FloatTree) extends FloatOperator {
427+
428+
protected def erasedExport(foreignCategory: Category, context: ExportContext) = {
429+
context.asScala.getOrElseUpdate(this, foreignCategory.float.tanh(operand0.export(foreignCategory, context)))
430+
}
431+
432+
protected def erasedAlphaConversion(context: AlphaConversionContext): Tree = {
433+
def converted = copy(operand0 = operand0.alphaConversion(context))
434+
context.asScala.getOrElseUpdate(this, converted)
435+
}
436+
}
437+
438+
424439
/** @group AST */
425440
@(silent @companionObject)
426441
final case class Sqrt(operand0: FloatTree) extends FloatOperator {
@@ -619,6 +634,11 @@ object Trees {
619634
term(Sqrt(operand.tree))
620635
}
621636

637+
@inline
638+
def tanh(operand: FloatTerm): FloatTerm = {
639+
term(Tanh(operand.tree))
640+
}
641+
622642
}
623643

624644
type FloatType <: (ValueType with Any) with FloatTreeType

0 commit comments

Comments
 (0)