@@ -8,7 +8,7 @@ import com.thoughtworks.raii.asynchronous._
8
8
import com .thoughtworks .raii .covariant ._
9
9
import com .thoughtworks .tryt .covariant ._
10
10
import com .typesafe .scalalogging .StrictLogging
11
- import org .lwjgl .opencl .CLCapabilities
11
+ import org .lwjgl .opencl .{ CL10 , CLCapabilities }
12
12
import org .lwjgl .system .Configuration
13
13
import org .nd4j .linalg .api .ndarray .INDArray
14
14
import org .nd4j .linalg .convolution .Convolution
@@ -23,6 +23,28 @@ import scala.util.Try
23
23
24
24
object benchmarks {
25
25
26
+ trait TensorState {
27
+ @ Param (Array (" CPU" , " GPU" ))
28
+ protected var deviceType : String = _
29
+
30
+ trait BenchmarkTensors
31
+ extends StrictLogging
32
+ with Tensors .UnsafeMathOptimizations
33
+ with Tensors .SuppressWarnings
34
+ with OpenCL .LogContextNotification
35
+ with OpenCL .GlobalExecutionContext
36
+ with OpenCL .UseFirstPlatform
37
+ with OpenCL .CommandQueuePool
38
+ with OpenCL .DontReleaseEventTooEarly
39
+ with Tensors .WangHashingRandomNumberGenerator {
40
+
41
+ @ transient
42
+ protected lazy val deviceIds : Seq [DeviceId ] = {
43
+ deviceIdsByType(classOf [CL10 ].getField(s " CL_DEVICE_TYPE_ $deviceType" ).get(null ).asInstanceOf [Int ])
44
+ }
45
+ }
46
+ }
47
+
26
48
@ Threads (value = Threads .MAX )
27
49
@ State (Scope .Benchmark )
28
50
class Nd4jTanh extends TanhState {
@@ -47,18 +69,8 @@ object benchmarks {
47
69
48
70
@ Threads (value = Threads .MAX )
49
71
@ State (Scope .Benchmark )
50
- class TensorTanh extends TanhState {
51
- trait Benchmarks
52
- extends StrictLogging
53
- with Tensors .UnsafeMathOptimizations
54
- with Tensors .SuppressWarnings
55
- with OpenCL .LogContextNotification
56
- with OpenCL .GlobalExecutionContext
57
- with OpenCL .UseAllCpuDevices
58
- with OpenCL .UseFirstPlatform
59
- with OpenCL .CommandQueuePool
60
- with OpenCL .DontReleaseEventTooEarly
61
- with Tensors .WangHashingRandomNumberGenerator {
72
+ class TensorTanh extends TanhState with TensorState {
73
+ trait Benchmarks extends BenchmarkTensors {
62
74
63
75
protected val numberOfCommandQueuesPerDevice : Int = 2
64
76
@@ -130,17 +142,8 @@ object benchmarks {
130
142
131
143
@ Threads (value = Threads .MAX )
132
144
@ State (Scope .Benchmark )
133
- class TensorSum extends SumState {
134
- trait Benchmarks
135
- extends StrictLogging
136
- with Tensors .UnsafeMathOptimizations
137
- with OpenCL .LogContextNotification
138
- with OpenCL .GlobalExecutionContext
139
- with OpenCL .UseAllCpuDevices
140
- with OpenCL .UseFirstPlatform
141
- with OpenCL .CommandQueuePool
142
- with OpenCL .DontReleaseEventTooEarly
143
- with Tensors .WangHashingRandomNumberGenerator {
145
+ class TensorSum extends SumState with TensorState {
146
+ trait Benchmarks extends BenchmarkTensors {
144
147
145
148
protected val numberOfCommandQueuesPerDevice : Int = 2
146
149
@@ -200,17 +203,8 @@ object benchmarks {
200
203
201
204
@ Threads (value = Threads .MAX )
202
205
@ State (Scope .Benchmark )
203
- class TensorRandomNormal extends RandomNormalState {
204
- trait Benchmarks
205
- extends StrictLogging
206
- with Tensors .UnsafeMathOptimizations
207
- with OpenCL .LogContextNotification
208
- with OpenCL .GlobalExecutionContext
209
- with OpenCL .UseAllCpuDevices
210
- with OpenCL .UseFirstPlatform
211
- with OpenCL .CommandQueuePool
212
- with OpenCL .DontReleaseEventTooEarly
213
- with Tensors .WangHashingRandomNumberGenerator {
206
+ class TensorRandomNormal extends RandomNormalState with TensorState {
207
+ trait Benchmarks extends BenchmarkTensors {
214
208
215
209
protected val numberOfCommandQueuesPerDevice : Int = 2
216
210
@@ -295,19 +289,9 @@ object benchmarks {
295
289
296
290
@ Threads (value = Threads .MAX )
297
291
@ State (Scope .Benchmark )
298
- class TensorConvolution extends ConvolutionState {
292
+ class TensorConvolution extends ConvolutionState with TensorState {
299
293
300
- trait Benchmarks
301
- extends StrictLogging
302
- with Tensors .UnsafeMathOptimizations
303
- with OpenCL .LogContextNotification
304
- with OpenCL .GlobalExecutionContext
305
- with OpenCL .UseAllCpuDevices
306
- with OpenCL .UseFirstPlatform
307
- with OpenCL .CommandQueuePool
308
- with OpenCL .DontReleaseEventTooEarly
309
- with Tensors .WangHashingRandomNumberGenerator
310
- with ConvolutionTensors {
294
+ trait Benchmarks extends BenchmarkTensors with ConvolutionTensors {
311
295
312
296
protected val numberOfCommandQueuesPerDevice = 2
313
297
0 commit comments