Skip to content

Commit 9cd6dcb

Browse files
committed
Switch Compute.scala device according to JMH parameter
1 parent a0acf0b commit 9cd6dcb

File tree

1 file changed: 31 additions and 47 deletions.

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import com.thoughtworks.raii.asynchronous._
88
import com.thoughtworks.raii.covariant._
99
import com.thoughtworks.tryt.covariant._
1010
import com.typesafe.scalalogging.StrictLogging
11-
import org.lwjgl.opencl.CLCapabilities
11+
import org.lwjgl.opencl.{CL10, CLCapabilities}
1212
import org.lwjgl.system.Configuration
1313
import org.nd4j.linalg.api.ndarray.INDArray
1414
import org.nd4j.linalg.convolution.Convolution
@@ -23,6 +23,28 @@ import scala.util.Try
2323

2424
object benchmarks {
2525

26+
trait TensorState {
27+
@Param(Array("CPU", "GPU"))
28+
protected var deviceType: String = _
29+
30+
trait BenchmarkTensors
31+
extends StrictLogging
32+
with Tensors.UnsafeMathOptimizations
33+
with Tensors.SuppressWarnings
34+
with OpenCL.LogContextNotification
35+
with OpenCL.GlobalExecutionContext
36+
with OpenCL.UseFirstPlatform
37+
with OpenCL.CommandQueuePool
38+
with OpenCL.DontReleaseEventTooEarly
39+
with Tensors.WangHashingRandomNumberGenerator {
40+
41+
@transient
42+
protected lazy val deviceIds: Seq[DeviceId] = {
43+
deviceIdsByType(classOf[CL10].getField(s"CL_DEVICE_TYPE_$deviceType").get(null).asInstanceOf[Int])
44+
}
45+
}
46+
}
47+
2648
@Threads(value = Threads.MAX)
2749
@State(Scope.Benchmark)
2850
class Nd4jTanh extends TanhState {
@@ -47,18 +69,8 @@ object benchmarks {
4769

4870
@Threads(value = Threads.MAX)
4971
@State(Scope.Benchmark)
50-
class TensorTanh extends TanhState {
51-
trait Benchmarks
52-
extends StrictLogging
53-
with Tensors.UnsafeMathOptimizations
54-
with Tensors.SuppressWarnings
55-
with OpenCL.LogContextNotification
56-
with OpenCL.GlobalExecutionContext
57-
with OpenCL.UseAllCpuDevices
58-
with OpenCL.UseFirstPlatform
59-
with OpenCL.CommandQueuePool
60-
with OpenCL.DontReleaseEventTooEarly
61-
with Tensors.WangHashingRandomNumberGenerator {
72+
class TensorTanh extends TanhState with TensorState {
73+
trait Benchmarks extends BenchmarkTensors {
6274

6375
protected val numberOfCommandQueuesPerDevice: Int = 2
6476

@@ -130,17 +142,8 @@ object benchmarks {
130142

131143
@Threads(value = Threads.MAX)
132144
@State(Scope.Benchmark)
133-
class TensorSum extends SumState {
134-
trait Benchmarks
135-
extends StrictLogging
136-
with Tensors.UnsafeMathOptimizations
137-
with OpenCL.LogContextNotification
138-
with OpenCL.GlobalExecutionContext
139-
with OpenCL.UseAllCpuDevices
140-
with OpenCL.UseFirstPlatform
141-
with OpenCL.CommandQueuePool
142-
with OpenCL.DontReleaseEventTooEarly
143-
with Tensors.WangHashingRandomNumberGenerator {
145+
class TensorSum extends SumState with TensorState {
146+
trait Benchmarks extends BenchmarkTensors {
144147

145148
protected val numberOfCommandQueuesPerDevice: Int = 2
146149

@@ -200,17 +203,8 @@ object benchmarks {
200203

201204
@Threads(value = Threads.MAX)
202205
@State(Scope.Benchmark)
203-
class TensorRandomNormal extends RandomNormalState {
204-
trait Benchmarks
205-
extends StrictLogging
206-
with Tensors.UnsafeMathOptimizations
207-
with OpenCL.LogContextNotification
208-
with OpenCL.GlobalExecutionContext
209-
with OpenCL.UseAllCpuDevices
210-
with OpenCL.UseFirstPlatform
211-
with OpenCL.CommandQueuePool
212-
with OpenCL.DontReleaseEventTooEarly
213-
with Tensors.WangHashingRandomNumberGenerator {
206+
class TensorRandomNormal extends RandomNormalState with TensorState {
207+
trait Benchmarks extends BenchmarkTensors {
214208

215209
protected val numberOfCommandQueuesPerDevice: Int = 2
216210

@@ -295,19 +289,9 @@ object benchmarks {
295289

296290
@Threads(value = Threads.MAX)
297291
@State(Scope.Benchmark)
298-
class TensorConvolution extends ConvolutionState {
292+
class TensorConvolution extends ConvolutionState with TensorState {
299293

300-
trait Benchmarks
301-
extends StrictLogging
302-
with Tensors.UnsafeMathOptimizations
303-
with OpenCL.LogContextNotification
304-
with OpenCL.GlobalExecutionContext
305-
with OpenCL.UseAllCpuDevices
306-
with OpenCL.UseFirstPlatform
307-
with OpenCL.CommandQueuePool
308-
with OpenCL.DontReleaseEventTooEarly
309-
with Tensors.WangHashingRandomNumberGenerator
310-
with ConvolutionTensors {
294+
trait Benchmarks extends BenchmarkTensors with ConvolutionTensors {
311295

312296
protected val numberOfCommandQueuesPerDevice = 2
313297

0 commit comments

Comments
 (0)