Skip to content

Commit c34f558

Browse files
authored
Merge pull request #94 from ThoughtWorksInc/tanh
Add tanh method
2 parents e23dca3 + 3afdc0b commit c34f558

File tree

5 files changed

+58
-15
lines changed

5 files changed

+58
-15
lines changed

Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ object Expressions {
120120
def exp(operand: FloatTerm): FloatTerm
121121
def abs(operand: FloatTerm): FloatTerm
122122
def sqrt(operand: FloatTerm): FloatTerm
123+
def tanh(operand: FloatTerm): FloatTerm
123124
}
124125

125126
/** @template */

OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ trait OpenCLKernelBuilder extends AllExpressions {
265265
"""
266266
float.termFactory.newInstance(valueTermName)
267267
}
268+
269+
def tanh(operand0: FloatTerm): FloatTerm = {
270+
val valueTermName = freshName("")
271+
localDefinitions += fastraw"""
272+
const ${operand0.typeCode} $valueTermName = tanh(${operand0.termCode});
273+
"""
274+
float.termFactory.newInstance(valueTermName)
275+
}
268276
}
269277

270278
type FloatType <: (ValueType with Any) with ClFloatType

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,10 @@ trait Tensors extends OpenCL {
502502
leftHandSide.derivedTensor(trees.float.sqrt(leftHandSide.closure.asInstanceOf[FloatTerm]))
503503
}
504504

505+
def tanh(leftHandSide: Tensor): Tensor = {
506+
leftHandSide.derivedTensor(trees.float.tanh(leftHandSide.closure.asInstanceOf[FloatTerm]))
507+
}
508+
505509
def exp(leftHandSide: Tensor): Tensor = {
506510
leftHandSide.derivedTensor(trees.float.exp(leftHandSide.closure.asInstanceOf[FloatTerm]))
507511
}

Trees/src/main/scala/com/thoughtworks/compute/Trees.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,21 @@ object Trees {
421421
}
422422
}
423423

424+
/** @group AST */
425+
@(silent @companionObject)
426+
final case class Tanh(operand0: FloatTree) extends FloatOperator {
427+
428+
protected def erasedExport(foreignCategory: Category, context: ExportContext) = {
429+
context.asScala.getOrElseUpdate(this, foreignCategory.float.tanh(operand0.export(foreignCategory, context)))
430+
}
431+
432+
protected def erasedAlphaConversion(context: AlphaConversionContext): Tree = {
433+
def converted = copy(operand0 = operand0.alphaConversion(context))
434+
context.asScala.getOrElseUpdate(this, converted)
435+
}
436+
}
437+
438+
424439
/** @group AST */
425440
@(silent @companionObject)
426441
final case class Sqrt(operand0: FloatTree) extends FloatOperator {
@@ -619,6 +634,11 @@ object Trees {
619634
term(Sqrt(operand.tree))
620635
}
621636

637+
@inline
638+
def tanh(operand: FloatTerm): FloatTerm = {
639+
term(Tanh(operand.tree))
640+
}
641+
622642
}
623643

624644
type FloatType <: (ValueType with Any) with FloatTreeType

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import com.thoughtworks.raii.covariant._
99
import com.thoughtworks.tryt.covariant._
1010
import com.typesafe.scalalogging.StrictLogging
1111
import org.lwjgl.opencl.CLCapabilities
12+
import org.lwjgl.system.Configuration
1213
import org.nd4j.linalg.api.ndarray.INDArray
1314
import org.nd4j.linalg.convolution.Convolution
1415
import org.nd4j.linalg.factory.Nd4j
@@ -24,24 +25,29 @@ object benchmarks {
2425

2526
@Threads(value = Threads.MAX)
2627
@State(Scope.Benchmark)
27-
class Nd4jSigmoid extends SigmoidState {
28+
class Nd4jTanh extends TanhState {
2829

2930
@transient
3031
private lazy val input = Nd4j.randn(Array.fill(numberOfDimensions)(size))
31-
private def sigmoid(x: INDArray): INDArray = {
32+
private def tanh(x: INDArray): INDArray = {
3233
val expX = Transforms.exp(x)
3334
expX.div(expX.add(1.0))
3435
}
3536
@Benchmark
36-
final def nd4jSigmoidBenchmark(): Array[Float] = {
37-
sigmoid(input).data().asFloat()
37+
final def nd4jTanhBenchmark(): Array[Float] = {
38+
(0 until numberOfIterations)
39+
.foldLeft(input) { (input, _) =>
40+
Transforms.tanh(input)
41+
}
42+
.data()
43+
.asFloat()
3844
}
3945

4046
}
4147

4248
@Threads(value = Threads.MAX)
4349
@State(Scope.Benchmark)
44-
class TensorSigmoid extends SigmoidState {
50+
class TensorTanh extends TanhState {
4551
trait Benchmarks
4652
extends StrictLogging
4753
with Tensors.UnsafeMathOptimizations
@@ -56,18 +62,18 @@ object benchmarks {
5662

5763
protected val numberOfCommandQueuesPerDevice: Int = 2
5864

59-
private def sigmoid(x: Tensor): Tensor = {
60-
val expX = Tensor.exp(x)
61-
expX / (expX + Tensor.fill(1.0f, expX.shape))
62-
}
63-
6465
def doBenchmark(): Do[() => Array[Float]] = {
6566
val input = Tensor.randomNormal(Array.fill(numberOfDimensions)(size))
6667

6768
input.doBuffer.map { _ =>
6869
{ () =>
69-
sigmoid(input).flatArray.run.blockingAwait
70-
70+
(0 until numberOfIterations)
71+
.foldLeft(input) { (input, _) =>
72+
Tensor.tanh(input)
73+
}
74+
.flatArray
75+
.run
76+
.blockingAwait
7177
}
7278
}
7379
}
@@ -77,6 +83,7 @@ object benchmarks {
7783

7884
@Setup
7985
final def setup(): Unit = {
86+
// Configuration.OPENCL_LIBRARY_NAME.set("/opt/pocl-1.1/lib/libOpenCL.dylib")
8087
assert(benchmarkResouce == null)
8188
val Do(TryT(ResourceT(resourceContinuation))) =
8289
Do.monadicCloseable(Factory[Benchmarks].newInstance()).flatMap(_.doBenchmark())
@@ -91,14 +98,17 @@ object benchmarks {
9198
}
9299

93100
@Benchmark
94-
final def tensorSigmoidBenchmark(): Array[Float] = {
101+
final def tensorTanhBenchmark(): Array[Float] = {
95102
benchmarkResouce.value.get.apply()
96103
}
97104

98105
}
99106

100-
trait SigmoidState {
101-
@Param(Array("3", "2", "1"))
107+
trait TanhState {
108+
@Param(Array("100", "10", "1"))
109+
protected var numberOfIterations: Int = _
110+
111+
@Param(Array("2", "3", "1"))
102112
protected var numberOfDimensions: Int = _
103113

104114
@Param(Array("128", "64", "32", "16"))

0 commit comments

Comments
 (0)