Skip to content

Commit 8e868ed

Browse files
authored
Merge pull request #97 from ThoughtWorksInc/switch-gpu
Search device according to sbt setting key or JMH parameter
2 parents acdde8c + be85409 commit 8e868ed

File tree

3 files changed

+154
-127
lines changed

3 files changed

+154
-127
lines changed

OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala

Lines changed: 92 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,29 @@ import scala.language.higherKinds
5050
*/
5151
object OpenCL {
5252

53+
final case class PlatformId[Owner <: Singleton with OpenCL](handle: Long) extends AnyVal {
54+
55+
final def deviceIdsByType(deviceType: Int): Seq[DeviceId[Owner]] = {
56+
val Array(numberOfDevices) = {
57+
val a = Array(0)
58+
checkErrorCode(clGetDeviceIDs(handle, deviceType, null, a))
59+
a
60+
}
61+
val stack = stackPush()
62+
try {
63+
val deviceIdBuffer = stack.mallocPointer(numberOfDevices)
64+
checkErrorCode(clGetDeviceIDs(handle, deviceType, deviceIdBuffer, null: IntBuffer))
65+
for (i <- 0 until numberOfDevices) yield {
66+
val deviceId = deviceIdBuffer.get(i)
67+
new DeviceId[Owner](deviceId)
68+
}
69+
} finally {
70+
stack.close()
71+
}
72+
}
73+
74+
}
75+
5376
/** Returns a [[String]] for the C string `address`.
5477
*
5578
* @note We don't know the exact charset of the C string. Use [[memASCII]] because lwjgl treats them as ASCII.
@@ -96,33 +119,33 @@ object OpenCL {
96119
}
97120
})
98121
object Exceptions {
99-
final class MisalignedSubBufferOffset extends IllegalArgumentException
122+
final class MisalignedSubBufferOffset(message: String = null) extends IllegalArgumentException(message)
100123

101-
final class ExecStatusErrorForEventsInWaitList extends IllegalArgumentException
124+
final class ExecStatusErrorForEventsInWaitList(message: String = null) extends IllegalArgumentException(message)
102125

103-
final class InvalidProperty extends IllegalArgumentException
126+
final class InvalidProperty(message: String = null) extends IllegalArgumentException(message)
104127

105-
final class PlatformNotFoundKhr extends IllegalStateException
128+
final class PlatformNotFoundKhr(message: String = null) extends NoSuchElementException(message)
106129

107-
final class DeviceNotFound extends IllegalArgumentException
130+
final class DeviceNotFound(message: String = null) extends NoSuchElementException(message)
108131

109-
final class DeviceNotAvailable extends IllegalStateException
132+
final class DeviceNotAvailable(message: String = null) extends IllegalStateException(message)
110133

111-
final class CompilerNotAvailable extends IllegalStateException
134+
final class CompilerNotAvailable(message: String = null) extends IllegalStateException(message)
112135

113-
final class MemObjectAllocationFailure extends IllegalStateException
136+
final class MemObjectAllocationFailure(message: String = null) extends IllegalStateException(message)
114137

115-
final class OutOfResources extends IllegalStateException
138+
final class OutOfResources(message: String = null) extends IllegalStateException(message)
116139

117-
final class OutOfHostMemory extends IllegalStateException
140+
final class OutOfHostMemory(message: String = null) extends IllegalStateException(message)
118141

119-
final class ProfilingInfoNotAvailable extends IllegalStateException
142+
final class ProfilingInfoNotAvailable(message: String = null) extends IllegalStateException(message)
120143

121-
final class MemCopyOverlap extends IllegalStateException
144+
final class MemCopyOverlap(message: String = null) extends IllegalStateException(message)
122145

123-
final class ImageFormatMismatch extends IllegalStateException
146+
final class ImageFormatMismatch(message: String = null) extends IllegalStateException(message)
124147

125-
final class ImageFormatNotSupported extends IllegalStateException
148+
final class ImageFormatNotSupported(message: String = null) extends IllegalStateException(message)
126149

127150
final class BuildProgramFailure(buildLogs: Map[Long /* device id */, String] = Map.empty)
128151
extends IllegalStateException({
@@ -134,71 +157,71 @@ object OpenCL {
134157
.mkString("\n")
135158
})
136159

137-
final class MapFailure extends IllegalStateException
160+
final class MapFailure(message: String = null) extends IllegalStateException(message)
138161

139-
final class InvalidValue extends IllegalArgumentException
162+
final class InvalidValue(message: String = null) extends IllegalArgumentException(message)
140163

141-
final class InvalidDeviceType extends IllegalArgumentException
164+
final class InvalidDeviceType(message: String = null) extends IllegalArgumentException(message)
142165

143-
final class InvalidPlatform extends IllegalArgumentException
166+
final class InvalidPlatform(message: String = null) extends IllegalArgumentException(message)
144167

145-
final class InvalidDevice extends IllegalArgumentException
168+
final class InvalidDevice(message: String = null) extends IllegalArgumentException(message)
146169

147-
final class InvalidContext extends IllegalArgumentException
170+
final class InvalidContext(message: String = null) extends IllegalArgumentException(message)
148171

149-
final class InvalidQueueProperties extends IllegalArgumentException
172+
final class InvalidQueueProperties(message: String = null) extends IllegalArgumentException(message)
150173

151-
final class InvalidCommandQueue extends IllegalArgumentException
174+
final class InvalidCommandQueue(message: String = null) extends IllegalArgumentException(message)
152175

153-
final class InvalidHostPtr extends IllegalArgumentException
176+
final class InvalidHostPtr(message: String = null) extends IllegalArgumentException(message)
154177

155-
final class InvalidMemObject extends IllegalArgumentException
178+
final class InvalidMemObject(message: String = null) extends IllegalArgumentException(message)
156179

157-
final class InvalidImageFormatDescriptor extends IllegalArgumentException
180+
final class InvalidImageFormatDescriptor(message: String = null) extends IllegalArgumentException(message)
158181

159-
final class InvalidImageSize extends IllegalArgumentException
182+
final class InvalidImageSize(message: String = null) extends IllegalArgumentException(message)
160183

161-
final class InvalidSampler extends IllegalArgumentException
184+
final class InvalidSampler(message: String = null) extends IllegalArgumentException(message)
162185

163-
final class InvalidBinary extends IllegalArgumentException
186+
final class InvalidBinary(message: String = null) extends IllegalArgumentException(message)
164187

165-
final class InvalidBuildOptions extends IllegalArgumentException
188+
final class InvalidBuildOptions(message: String = null) extends IllegalArgumentException(message)
166189

167-
final class InvalidProgram extends IllegalArgumentException
190+
final class InvalidProgram(message: String = null) extends IllegalArgumentException(message)
168191

169-
final class InvalidProgramExecutable extends IllegalArgumentException
192+
final class InvalidProgramExecutable(message: String = null) extends IllegalArgumentException(message)
170193

171-
final class InvalidKernelName extends IllegalArgumentException
194+
final class InvalidKernelName(message: String = null) extends IllegalArgumentException(message)
172195

173-
final class InvalidKernelDefinition extends IllegalArgumentException
196+
final class InvalidKernelDefinition(message: String = null) extends IllegalArgumentException(message)
174197

175-
final class InvalidKernel extends IllegalArgumentException
198+
final class InvalidKernel(message: String = null) extends IllegalArgumentException(message)
176199

177-
final class InvalidArgIndex extends IllegalArgumentException
200+
final class InvalidArgIndex(message: String = null) extends IllegalArgumentException(message)
178201

179-
final class InvalidArgValue extends IllegalArgumentException
202+
final class InvalidArgValue(message: String = null) extends IllegalArgumentException(message)
180203

181-
final class InvalidArgSize extends IllegalArgumentException
204+
final class InvalidArgSize(message: String = null) extends IllegalArgumentException(message)
182205

183-
final class InvalidKernelArgs extends IllegalArgumentException
206+
final class InvalidKernelArgs(message: String = null) extends IllegalArgumentException(message)
184207

185-
final class InvalidWorkDimension extends IllegalArgumentException
208+
final class InvalidWorkDimension(message: String = null) extends IllegalArgumentException(message)
186209

187-
final class InvalidWorkGroupSize extends IllegalArgumentException
210+
final class InvalidWorkGroupSize(message: String = null) extends IllegalArgumentException(message)
188211

189-
final class InvalidWorkItemSize extends IllegalArgumentException
212+
final class InvalidWorkItemSize(message: String = null) extends IllegalArgumentException(message)
190213

191-
final class InvalidGlobalOffset extends IllegalArgumentException
214+
final class InvalidGlobalOffset(message: String = null) extends IllegalArgumentException(message)
192215

193-
final class InvalidEventWaitList extends IllegalArgumentException
216+
final class InvalidEventWaitList(message: String = null) extends IllegalArgumentException(message)
194217

195-
final class InvalidEvent extends IllegalArgumentException
218+
final class InvalidEvent(message: String = null) extends IllegalArgumentException(message)
196219

197-
final class InvalidOperation extends IllegalArgumentException
220+
final class InvalidOperation(message: String = null) extends IllegalArgumentException(message)
198221

199-
final class InvalidBufferSize extends IllegalArgumentException
222+
final class InvalidBufferSize(message: String = null) extends IllegalArgumentException(message)
200223

201-
final class InvalidGlobalWorkSize extends IllegalArgumentException
224+
final class InvalidGlobalWorkSize(message: String = null) extends IllegalArgumentException(message)
202225

203226
final class UnknownErrorCode(errorCode: Int) extends IllegalStateException(s"Unknown error code: $errorCode")
204227

@@ -265,25 +288,18 @@ object OpenCL {
265288
}
266289
}
267290

268-
trait UseFirstPlatform {
291+
trait UseFirstPlatform extends OpenCL {
269292
@transient
270-
protected lazy val platformId: Long = {
271-
val stack = stackPush()
272-
try {
273-
val platformIdBuffer = stack.mallocPointer(1)
274-
checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer))
275-
platformIdBuffer.get(0)
276-
} finally {
277-
stack.close()
278-
}
293+
protected lazy val platformId: PlatformId = {
294+
platformIds.head
279295
}
280296
}
281297

282298
trait UseAllDevices extends OpenCL {
283299

284300
@transient
285301
protected lazy val deviceIds: Seq[DeviceId] = {
286-
deviceIdsByType(CL_DEVICE_TYPE_ALL)
302+
platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL)
287303
}
288304

289305
}
@@ -292,7 +308,7 @@ object OpenCL {
292308

293309
@transient
294310
protected lazy val deviceIds: Seq[DeviceId] = {
295-
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_ALL)
311+
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL)
296312
Seq(allDeviceIds.head)
297313
}
298314

@@ -302,23 +318,23 @@ object OpenCL {
302318

303319
@transient
304320
protected lazy val deviceIds: Seq[DeviceId] = {
305-
deviceIdsByType(CL_DEVICE_TYPE_GPU)
321+
platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU)
306322
}
307323
}
308324

309325
trait UseFirstGpuDevice extends OpenCL {
310326

311327
@transient
312328
protected lazy val deviceIds: Seq[DeviceId] = {
313-
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_GPU)
329+
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU)
314330
Seq(allDeviceIds.head)
315331
}
316332
}
317333
trait UseFirstCpuDevice extends OpenCL {
318334

319335
@transient
320336
protected lazy val deviceIds: Seq[DeviceId] = {
321-
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_CPU)
337+
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU)
322338
Seq(allDeviceIds.head)
323339
}
324340
}
@@ -327,7 +343,7 @@ object OpenCL {
327343

328344
@transient
329345
protected lazy val deviceIds: Seq[DeviceId] = {
330-
deviceIdsByType(CL_DEVICE_TYPE_CPU)
346+
platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU)
331347
}
332348
}
333349

@@ -1010,20 +1026,18 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
10101026
type Event = OpenCL.Event[this.type]
10111027
type CommandQueue = OpenCL.CommandQueue[this.type]
10121028
type DeviceId = OpenCL.DeviceId[this.type]
1029+
type PlatformId = OpenCL.PlatformId[this.type]
10131030

1014-
protected final def deviceIdsByType(deviceType: Int): Seq[DeviceId] = {
1015-
val Array(numberOfDevices) = {
1016-
val a = Array(0)
1017-
checkErrorCode(clGetDeviceIDs(platformId, deviceType, null, a))
1018-
a
1019-
}
1031+
def platformIds: Seq[PlatformId] = {
10201032
val stack = stackPush()
10211033
try {
1022-
val deviceIdBuffer = stack.mallocPointer(numberOfDevices)
1023-
checkErrorCode(clGetDeviceIDs(platformId, deviceType, deviceIdBuffer, null: IntBuffer))
1024-
for (i <- 0 until deviceIdBuffer.capacity()) yield {
1025-
val deviceId = deviceIdBuffer.get(i)
1026-
new DeviceId(deviceId)
1034+
val numberOfPlatformsBuffer = stack.mallocInt(1)
1035+
checkErrorCode(clGetPlatformIDs(null, numberOfPlatformsBuffer))
1036+
val numberOfPlatforms = numberOfPlatformsBuffer.get(0)
1037+
val platformIdBuffer = stack.mallocPointer(numberOfPlatforms)
1038+
checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer))
1039+
(0 until numberOfPlatforms).map { i =>
1040+
new PlatformId(platformIdBuffer.get(i))
10271041
}
10281042
} finally {
10291043
stack.close()
@@ -1174,20 +1188,20 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
11741188

11751189
import OpenCL._
11761190

1177-
protected val platformId: Long
1191+
protected val platformId: PlatformId
11781192
protected val deviceIds: Seq[DeviceId]
11791193

11801194
@transient
11811195
protected lazy val platformCapabilities: CLCapabilities = {
1182-
CL.createPlatformCapabilities(platformId)
1196+
CL.createPlatformCapabilities(platformId.handle)
11831197
}
11841198

11851199
protected def createCommandQueue(deviceId: DeviceId, properties: Map[Int, Long]): CommandQueue = new CommandQueue(
11861200
if (deviceCapabilities(deviceId).OpenCL20) {
11871201
val cl20Properties = (properties.view.flatMap { case (key, value) => Seq(key, value) } ++ Seq(0L)).toArray
11881202
val a = Array(0)
11891203
val commandQueue =
1190-
clCreateCommandQueueWithProperties(platformId, deviceId.handle, cl20Properties, a)
1204+
clCreateCommandQueueWithProperties(platformId.handle, deviceId.handle, cl20Properties, a)
11911205
checkErrorCode(a(0))
11921206
commandQueue
11931207
} else {
@@ -1211,7 +1225,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
12111225
val stack = stackPush()
12121226
try {
12131227
val errorCodeBuffer = stack.ints(CL_SUCCESS)
1214-
val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId, 0)
1228+
val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId.handle, 0)
12151229
val deviceIdBuffer = stack.pointers(deviceIds.view.map(_.handle): _*)
12161230
val context =
12171231
clCreateContext(contextProperties,

benchmarks/build.sbt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ enablePlugins(JmhPlugin)
22

33
libraryDependencies += "org.nd4j" % "nd4j-api" % "0.8.0"
44

5-
libraryDependencies += "org.nd4j" % "nd4j-cuda-8.0-platform" % "0.8.0"
5+
val nd4jRuntime = settingKey[String]("\"cuda-8.0\" to run benchmark on GPU, \"native\" to run benchmark on CPU.")
66

7-
libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "0.8.0"
7+
nd4jRuntime in Global := "native"
8+
9+
libraryDependencies += "org.nd4j" % s"nd4j-${nd4jRuntime.value}-platform" % "0.8.0"
810

911
libraryDependencies += ("org.lwjgl" % "lwjgl" % "3.1.6").jar().classifier {
1012
import scala.util.Properties._

0 commit comments

Comments
 (0)