Skip to content

Commit 2f7db28

Browse files
committed
Search for devices on all platforms in the tensor benchmark
1 parent bbee23f commit 2f7db28

File tree

2 files changed

+77
-36
lines changed

2 files changed

+77
-36
lines changed

OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,29 @@ import scala.language.higherKinds
5050
*/
5151
object OpenCL {
5252

53+
/** Typed wrapper around a raw OpenCL platform handle.
  *
  * @tparam Owner the singleton type of the [[OpenCL]] instance this id belongs to
  * @param handle the native platform handle handed to the OpenCL API calls
  */
final case class PlatformId[Owner <: Singleton with OpenCL](handle: Long) extends AnyVal {

  /** Lists the devices of this platform that match `deviceType`.
    *
    * `clGetDeviceIDs` is called twice: once without an output buffer to learn
    * how many devices match, then again to fill a pointer buffer of that size.
    * Both status codes are passed through `checkErrorCode`.
    *
    * @param deviceType a `CL_DEVICE_TYPE_*` value, e.g. `CL_DEVICE_TYPE_ALL`
    * @return one [[DeviceId]] per handle written to the buffer, in buffer order
    */
  final def deviceIdsByType(deviceType: Int): Seq[DeviceId[Owner]] = {
    // Counting pass: no device buffer, only the number of matches is written.
    val deviceCount = Array(0)
    checkErrorCode(clGetDeviceIDs(handle, deviceType, null, deviceCount))
    val numberOfDevices = deviceCount(0)

    val stack = stackPush()
    try {
      val handles = stack.mallocPointer(numberOfDevices)
      checkErrorCode(clGetDeviceIDs(handle, deviceType, handles, null: IntBuffer))
      // Range#map is strict, so every handle is copied out of the buffer
      // before the stack frame is closed in `finally`.
      (0 until numberOfDevices).map(index => new DeviceId[Owner](handles.get(index)))
    } finally {
      stack.close()
    }
  }

}
75+
5376
/** Returns a [[String]] for the C string `address`.
5477
*
5578
* @note We don't know the exact charset of the C string. Use [[memASCII]] because lwjgl treats them as ASCII.
@@ -265,25 +288,18 @@ object OpenCL {
265288
}
266289
}
267290

268-
trait UseFirstPlatform {
291+
trait UseFirstPlatform extends OpenCL {
269292
@transient
270-
protected lazy val platformId: Long = {
271-
val stack = stackPush()
272-
try {
273-
val platformIdBuffer = stack.mallocPointer(1)
274-
checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer))
275-
platformIdBuffer.get(0)
276-
} finally {
277-
stack.close()
278-
}
293+
protected lazy val platformId: PlatformId = {
294+
platformIds.head
279295
}
280296
}
281297

282298
trait UseAllDevices extends OpenCL {
283299

284300
@transient
285301
protected lazy val deviceIds: Seq[DeviceId] = {
286-
deviceIdsByType(CL_DEVICE_TYPE_ALL)
302+
platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL)
287303
}
288304

289305
}
@@ -292,7 +308,7 @@ object OpenCL {
292308

293309
@transient
294310
protected lazy val deviceIds: Seq[DeviceId] = {
295-
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_ALL)
311+
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL)
296312
Seq(allDeviceIds.head)
297313
}
298314

@@ -302,23 +318,23 @@ object OpenCL {
302318

303319
@transient
304320
protected lazy val deviceIds: Seq[DeviceId] = {
305-
deviceIdsByType(CL_DEVICE_TYPE_GPU)
321+
platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU)
306322
}
307323
}
308324

309325
trait UseFirstGpuDevice extends OpenCL {
310326

311327
@transient
312328
protected lazy val deviceIds: Seq[DeviceId] = {
313-
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_GPU)
329+
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU)
314330
Seq(allDeviceIds.head)
315331
}
316332
}
317333
trait UseFirstCpuDevice extends OpenCL {
318334

319335
@transient
320336
protected lazy val deviceIds: Seq[DeviceId] = {
321-
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_CPU)
337+
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU)
322338
Seq(allDeviceIds.head)
323339
}
324340
}
@@ -327,7 +343,7 @@ object OpenCL {
327343

328344
@transient
329345
protected lazy val deviceIds: Seq[DeviceId] = {
330-
deviceIdsByType(CL_DEVICE_TYPE_CPU)
346+
platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU)
331347
}
332348
}
333349

@@ -1010,20 +1026,18 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
10101026
type Event = OpenCL.Event[this.type]
10111027
type CommandQueue = OpenCL.CommandQueue[this.type]
10121028
type DeviceId = OpenCL.DeviceId[this.type]
1029+
type PlatformId = OpenCL.PlatformId[this.type]
10131030

1014-
protected final def deviceIdsByType(deviceType: Int): Seq[DeviceId] = {
1015-
val Array(numberOfDevices) = {
1016-
val a = Array(0)
1017-
checkErrorCode(clGetDeviceIDs(platformId, deviceType, null, a))
1018-
a
1019-
}
1031+
def platformIds: Seq[PlatformId] = {
10201032
val stack = stackPush()
10211033
try {
1022-
val deviceIdBuffer = stack.mallocPointer(numberOfDevices)
1023-
checkErrorCode(clGetDeviceIDs(platformId, deviceType, deviceIdBuffer, null: IntBuffer))
1024-
for (i <- 0 until deviceIdBuffer.capacity()) yield {
1025-
val deviceId = deviceIdBuffer.get(i)
1026-
new DeviceId(deviceId)
1034+
val numberOfPlatformsBuffer = stack.mallocInt(1)
1035+
checkErrorCode(clGetPlatformIDs(null, numberOfPlatformsBuffer))
1036+
val numberOfPlatforms = numberOfPlatformsBuffer.get(0)
1037+
val platformIdBuffer = stack.mallocPointer(numberOfPlatforms)
1038+
checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer))
1039+
(0 until numberOfPlatforms).map { i =>
1040+
new PlatformId(platformIdBuffer.get(i))
10271041
}
10281042
} finally {
10291043
stack.close()
@@ -1174,20 +1188,20 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
11741188

11751189
import OpenCL._
11761190

1177-
protected val platformId: Long
1191+
protected val platformId: PlatformId
11781192
protected val deviceIds: Seq[DeviceId]
11791193

11801194
@transient
11811195
protected lazy val platformCapabilities: CLCapabilities = {
1182-
CL.createPlatformCapabilities(platformId)
1196+
CL.createPlatformCapabilities(platformId.handle)
11831197
}
11841198

11851199
protected def createCommandQueue(deviceId: DeviceId, properties: Map[Int, Long]): CommandQueue = new CommandQueue(
11861200
if (deviceCapabilities(deviceId).OpenCL20) {
11871201
val cl20Properties = (properties.view.flatMap { case (key, value) => Seq(key, value) } ++ Seq(0L)).toArray
11881202
val a = Array(0)
11891203
val commandQueue =
1190-
clCreateCommandQueueWithProperties(platformId, deviceId.handle, cl20Properties, a)
1204+
clCreateCommandQueueWithProperties(platformId.handle, deviceId.handle, cl20Properties, a)
11911205
checkErrorCode(a(0))
11921206
commandQueue
11931207
} else {
@@ -1211,7 +1225,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
12111225
val stack = stackPush()
12121226
try {
12131227
val errorCodeBuffer = stack.ints(CL_SUCCESS)
1214-
val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId, 0)
1228+
val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId.handle, 0)
12151229
val deviceIdBuffer = stack.pointers(deviceIds.view.map(_.handle): _*)
12161230
val context =
12171231
clCreateContext(contextProperties,

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.thoughtworks.compute
22

3+
import com.thoughtworks.compute.OpenCL.Exceptions.DeviceNotFound
34
import com.thoughtworks.compute.benchmarks.RandomNormalState
45
import com.thoughtworks.feature.Factory
56
import com.thoughtworks.future._
@@ -25,23 +26,49 @@ object benchmarks {
2526

2627
trait TensorState {
2728
@Param(Array("CPU", "GPU"))
28-
protected var deviceType: String = _
29+
protected var tensorDeviceType: String = _
2930

3031
trait BenchmarkTensors
3132
extends StrictLogging
3233
with Tensors.UnsafeMathOptimizations
3334
with Tensors.SuppressWarnings
3435
with OpenCL.LogContextNotification
3536
with OpenCL.GlobalExecutionContext
36-
with OpenCL.UseFirstPlatform
3737
with OpenCL.CommandQueuePool
3838
with OpenCL.DontReleaseEventTooEarly
3939
with Tensors.WangHashingRandomNumberGenerator {
40-
4140
@transient
42-
protected lazy val deviceIds: Seq[DeviceId] = {
43-
deviceIdsByType(classOf[CL10].getField(s"CL_DEVICE_TYPE_$deviceType").get(null).asInstanceOf[Int])
41+
protected lazy val (platformId: PlatformId, deviceIds: Seq[DeviceId]) = {
  // Picks the first platform that exposes at least one device of the requested
  // type, paired with those devices. Throws DeviceNotFound when no platform does.

  // Resolve the JMH parameter ("CPU" / "GPU") to the matching
  // CL_DEVICE_TYPE_* constant, read reflectively as a static field of CL10.
  val deviceType = classOf[CL10].getField(s"CL_DEVICE_TYPE_$tensorDeviceType").get(null).asInstanceOf[Int]

  // Extractor that succeeds only for platforms with a non-empty device list.
  object MatchDeviceType {
    def unapply(platformId: PlatformId): Option[Seq[DeviceId]] = {
      val devices: Seq[DeviceId] =
        try {
          platformId.deviceIdsByType(deviceType)
        } catch {
          // A platform without a matching device is skipped instead of
          // aborting the search over the remaining platforms.
          case _: DeviceNotFound => Nil
        }
      if (devices.nonEmpty) Some(devices) else None
    }
  }

  platformIds
    .collectFirst {
      case platformId @ MatchDeviceType(deviceIds) => (platformId, deviceIds)
    }
    .getOrElse(throw new DeviceNotFound(s"$tensorDeviceType device is not found"))
}
71+
4572
}
4673
}
4774

0 commit comments

Comments
 (0)