Skip to content

Commit 2a5182e

Browse files
committed
Make DeviceId a case class instead of raw Long
1 parent 097ac0b commit 2a5182e

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

OpenCL/build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ fork := true
3535

3636
scalacOptions += "-Ypartial-unification"
3737

38-
libraryDependencies += "com.thoughtworks.each" %% "each" % "3.3.1" % Test
38+
libraryDependencies += "com.thoughtworks.each" %% "each" % "3.3.1"
3939

4040
enablePlugins(Example)

OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ import org.lwjgl.system.MemoryStack._
1818
import org.lwjgl.system.Pointer._
1919

2020
import scala.collection.mutable
21-
import com.thoughtworks.compute.Memory.{Aux, Box}
21+
import com.thoughtworks.each.Monadic._
22+
import com.thoughtworks.compute.Memory.Box
2223
import com.thoughtworks.compute.OpenCL.{Event, checkErrorCode}, Event.Status
2324
import org.lwjgl.system.jni.JNINativeInterface
2425
import org.lwjgl.system._
@@ -86,7 +87,7 @@ object OpenCL {
8687
private def decodeString(byteBuffer: ByteBuffer): String = memASCII(byteBuffer)
8788

8889
@volatile
89-
var defaultLogger: (String, ByteBuffer) => Unit = { (errorInfo: String, data: ByteBuffer) =>
90+
var defaultLogger: (String, Option[ByteBuffer]) => Unit = { (errorInfo, data) =>
9091
// TODO: Add a test for in the case that Context is closed
9192
Console.err.println(raw"""An OpenCL notify comes out after its corresponding handler is freed
9293
message: $errorInfo
@@ -105,13 +106,17 @@ object OpenCL {
105106
private val contextCallback: CLContextCallback = CLContextCallback.create(new CLContextCallbackI {
106107
def invoke(errInfo: Long, privateInfo: Long, size: Long, userData: Long): Unit = {
107108
val errorInfo = decodeString(errInfo)
108-
val data = memByteBuffer(privateInfo, size.toInt)
109+
val dataOption = if (privateInfo != NULL) {
110+
Some(memByteBuffer(privateInfo, size.toInt))
111+
} else {
112+
None
113+
}
109114
memGlobalRefToObject[OpenCL](userData) match {
110115
case null =>
111-
defaultLogger(decodeString(errInfo), memByteBuffer(privateInfo, size.toInt))
116+
defaultLogger(decodeString(errInfo), dataOption)
112117
case opencl =>
113118
if (size.isValidInt) {
114-
opencl.handleOpenCLNotification(decodeString(errInfo), memByteBuffer(privateInfo, size.toInt))
119+
opencl.handleOpenCLNotification(decodeString(errInfo), dataOption)
115120
} else {
116121
throw new IllegalArgumentException(s"numberOfBytes($size) is too large")
117122
}
@@ -147,12 +152,13 @@ object OpenCL {
147152

148153
final class ImageFormatNotSupported(message: String = null) extends IllegalStateException(message)
149154

150-
final class BuildProgramFailure(buildLogs: Map[Long /* device id */, String] = Map.empty)
155+
final class BuildProgramFailure[Owner <: Singleton with OpenCL](
156+
buildLogs: Map[DeviceId[Owner], String] = Map.empty[DeviceId[Owner], String])
151157
extends IllegalStateException({
152158
buildLogs.view
153159
.map {
154160
case (deviceId, buildLog) =>
155-
f"CL_BUILD_PROGRAM_FAILURE on device 0x$deviceId%X:\n$buildLog"
161+
f"CL_BUILD_PROGRAM_FAILURE on device 0x${deviceId.handle}%X:\n$buildLog"
156162
}
157163
.mkString("\n")
158164
})
@@ -784,15 +790,17 @@ object OpenCL {
784790
result(0)
785791
}
786792

787-
def deviceIds: Seq[Long] = {
793+
def deviceIds: Seq[DeviceId[Owner]] = {
788794
val stack = stackPush()
789795
try {
790796
val sizeBuffer = stack.mallocPointer(1)
791797
checkErrorCode(clGetProgramInfo(this.handle, CL_PROGRAM_DEVICES, null: PointerBuffer, sizeBuffer))
792798
val numberOfDeviceIds = sizeBuffer.get(0).toInt / POINTER_SIZE
793799
val programDevicesBuffer = stack.mallocPointer(numberOfDeviceIds)
794800
checkErrorCode(clGetProgramInfo(this.handle, CL_PROGRAM_DEVICES, programDevicesBuffer, sizeBuffer))
795-
(0 until numberOfDeviceIds).map(programDevicesBuffer.get)
801+
(0 until numberOfDeviceIds).map { i =>
802+
DeviceId[Owner](programDevicesBuffer.get(i))
803+
}
796804
} finally {
797805
stack.close()
798806
}
@@ -821,16 +829,16 @@ object OpenCL {
821829
}
822830
}
823831

824-
private def buildLogs(deviceIds: Seq[Long]): Map[Long /* device ID */, String] = {
832+
private def buildLogs(deviceIds: Seq[DeviceId[Owner]]): Map[DeviceId[Owner], String] = {
825833
val stack = stackPush()
826834
try {
827835
val sizeBuffer = stack.mallocPointer(1)
828836
deviceIds.view.map { deviceId =>
829837
checkErrorCode(
830-
clGetProgramBuildInfo(this.handle, deviceId, CL_PROGRAM_BUILD_LOG, null: PointerBuffer, sizeBuffer))
838+
clGetProgramBuildInfo(this.handle, deviceId.handle, CL_PROGRAM_BUILD_LOG, null: PointerBuffer, sizeBuffer))
831839
val logBuffer = MemoryUtil.memAlloc(sizeBuffer.get(0).toInt) //stack.malloc()
832840
try {
833-
checkErrorCode(clGetProgramBuildInfo(this.handle, deviceId, CL_PROGRAM_BUILD_LOG, logBuffer, null))
841+
checkErrorCode(clGetProgramBuildInfo(this.handle, deviceId.handle, CL_PROGRAM_BUILD_LOG, logBuffer, null))
834842
(deviceId, decodeString(logBuffer))
835843
} finally {
836844
MemoryUtil.memFree(logBuffer)
@@ -841,7 +849,7 @@ object OpenCL {
841849
}
842850
}
843851

844-
private def checkBuildErrorCode(deviceIdsOption: Option[Seq[Long]], errorCode: Int): Unit = {
852+
private def checkBuildErrorCode(deviceIdsOption: Option[Seq[DeviceId[Owner]]], errorCode: Int): Unit = {
845853
errorCode match {
846854
case CL_BUILD_PROGRAM_FAILURE =>
847855
val logs = deviceIdsOption match {
@@ -853,10 +861,12 @@ object OpenCL {
853861
}
854862
}
855863

856-
def build(deviceIds: Seq[Long], options: CharSequence = ""): Unit = {
864+
def build(deviceIds: Seq[DeviceId[Owner]], options: CharSequence = ""): Unit = {
857865
val stack = stackPush()
858866
try {
859-
checkBuildErrorCode(Some(deviceIds), clBuildProgram(handle, stack.pointers(deviceIds: _*), options, null, NULL))
867+
checkBuildErrorCode(
868+
Some(deviceIds),
869+
clBuildProgram(handle, stack.pointers(deviceIds.view.map(_.handle): _*), options, null, NULL))
860870
} finally {
861871
stack.close()
862872
}
@@ -920,7 +930,7 @@ object OpenCL {
920930
commandQueue: CommandQueue,
921931
deviceBuffer: DeviceBuffer[Element],
922932
hostBuffer: Destination,
923-
preconditionEvents: Event*)(implicit memory: Aux[Element, Destination]): Do[Event] =
933+
preconditionEvents: Event*)(implicit memory: Memory.Aux[Element, Destination]): Do[Event] =
924934
super.enqueueReadBuffer(commandQueue, deviceBuffer, hostBuffer, preconditionEvents: _*).map { event =>
925935
@tailrec
926936
def enqueueEvent(): Unit = {
@@ -1010,8 +1020,13 @@ object OpenCL {
10101020

10111021
trait LogContextNotification extends OpenCL {
10121022

1013-
protected def handleOpenCLNotification(errorInfo: String, privateInfo: ByteBuffer): Unit = {
1014-
Logger.takingImplicit[ByteBuffer](logger.underlying).info(errorInfo)(privateInfo)
1023+
protected def handleOpenCLNotification(errorInfo: String, privateInfoOption: Option[ByteBuffer]): Unit = {
1024+
privateInfoOption match {
1025+
case None =>
1026+
logger.info(errorInfo)
1027+
case Some(privateInfo) =>
1028+
Logger.takingImplicit[ByteBuffer](logger.underlying).info(errorInfo)(privateInfo)
1029+
}
10151030
}
10161031
}
10171032

@@ -1184,7 +1199,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
11841199
releaseContext >> super.monadicClose
11851200
}
11861201

1187-
protected def handleOpenCLNotification(errorInfo: String, privateInfo: ByteBuffer): Unit
1202+
protected def handleOpenCLNotification(errorInfo: String, privateInfo: Option[ByteBuffer]): Unit
11881203

11891204
import OpenCL._
11901205

0 commit comments

Comments
 (0)