@@ -18,7 +18,8 @@ import org.lwjgl.system.MemoryStack._
18
18
import org .lwjgl .system .Pointer ._
19
19
20
20
import scala .collection .mutable
21
- import com .thoughtworks .compute .Memory .{Aux , Box }
21
+ import com .thoughtworks .each .Monadic ._
22
+ import com .thoughtworks .compute .Memory .Box
22
23
import com .thoughtworks .compute .OpenCL .{Event , checkErrorCode }, Event .Status
23
24
import org .lwjgl .system .jni .JNINativeInterface
24
25
import org .lwjgl .system ._
@@ -86,7 +87,7 @@ object OpenCL {
86
87
private def decodeString (byteBuffer : ByteBuffer ): String = memASCII(byteBuffer)
87
88
88
89
@ volatile
89
- var defaultLogger : (String , ByteBuffer ) => Unit = { (errorInfo : String , data : ByteBuffer ) =>
90
+ var defaultLogger : (String , Option [ ByteBuffer ] ) => Unit = { (errorInfo, data) =>
90
91
// TODO: Add a test for in the case that Context is closed
91
92
Console .err.println(raw """ An OpenCL notify comes out after its corresponding handler is freed
92
93
message: $errorInfo
@@ -105,13 +106,17 @@ object OpenCL {
105
106
private val contextCallback : CLContextCallback = CLContextCallback .create(new CLContextCallbackI {
106
107
def invoke (errInfo : Long , privateInfo : Long , size : Long , userData : Long ): Unit = {
107
108
val errorInfo = decodeString(errInfo)
108
- val data = memByteBuffer(privateInfo, size.toInt)
109
+ val dataOption = if (privateInfo != NULL ) {
110
+ Some (memByteBuffer(privateInfo, size.toInt))
111
+ } else {
112
+ None
113
+ }
109
114
memGlobalRefToObject[OpenCL ](userData) match {
110
115
case null =>
111
- defaultLogger(decodeString(errInfo), memByteBuffer(privateInfo, size.toInt) )
116
+ defaultLogger(decodeString(errInfo), dataOption )
112
117
case opencl =>
113
118
if (size.isValidInt) {
114
- opencl.handleOpenCLNotification(decodeString(errInfo), memByteBuffer(privateInfo, size.toInt) )
119
+ opencl.handleOpenCLNotification(decodeString(errInfo), dataOption )
115
120
} else {
116
121
throw new IllegalArgumentException (s " numberOfBytes( $size) is too large " )
117
122
}
@@ -147,12 +152,13 @@ object OpenCL {
147
152
148
153
final class ImageFormatNotSupported (message : String = null ) extends IllegalStateException (message)
149
154
150
- final class BuildProgramFailure (buildLogs : Map [Long /* device id */ , String ] = Map .empty)
155
+ final class BuildProgramFailure [Owner <: Singleton with OpenCL ](
156
+ buildLogs : Map [DeviceId [Owner ], String ] = Map .empty[DeviceId [Owner ], String ])
151
157
extends IllegalStateException ({
152
158
buildLogs.view
153
159
.map {
154
160
case (deviceId, buildLog) =>
155
- f " CL_BUILD_PROGRAM_FAILURE on device 0x $deviceId%X: \n $buildLog"
161
+ f " CL_BUILD_PROGRAM_FAILURE on device 0x ${ deviceId.handle} %X: \n $buildLog"
156
162
}
157
163
.mkString(" \n " )
158
164
})
@@ -784,15 +790,17 @@ object OpenCL {
784
790
result(0 )
785
791
}
786
792
787
- def deviceIds : Seq [Long ] = {
793
+ def deviceIds : Seq [DeviceId [ Owner ] ] = {
788
794
val stack = stackPush()
789
795
try {
790
796
val sizeBuffer = stack.mallocPointer(1 )
791
797
checkErrorCode(clGetProgramInfo(this .handle, CL_PROGRAM_DEVICES , null : PointerBuffer , sizeBuffer))
792
798
val numberOfDeviceIds = sizeBuffer.get(0 ).toInt / POINTER_SIZE
793
799
val programDevicesBuffer = stack.mallocPointer(numberOfDeviceIds)
794
800
checkErrorCode(clGetProgramInfo(this .handle, CL_PROGRAM_DEVICES , programDevicesBuffer, sizeBuffer))
795
- (0 until numberOfDeviceIds).map(programDevicesBuffer.get)
801
+ (0 until numberOfDeviceIds).map { i =>
802
+ DeviceId [Owner ](programDevicesBuffer.get(i))
803
+ }
796
804
} finally {
797
805
stack.close()
798
806
}
@@ -821,16 +829,16 @@ object OpenCL {
821
829
}
822
830
}
823
831
824
- private def buildLogs (deviceIds : Seq [Long ] ): Map [Long /* device ID */ , String ] = {
832
+ private def buildLogs (deviceIds : Seq [DeviceId [ Owner ]] ): Map [DeviceId [ Owner ] , String ] = {
825
833
val stack = stackPush()
826
834
try {
827
835
val sizeBuffer = stack.mallocPointer(1 )
828
836
deviceIds.view.map { deviceId =>
829
837
checkErrorCode(
830
- clGetProgramBuildInfo(this .handle, deviceId, CL_PROGRAM_BUILD_LOG , null : PointerBuffer , sizeBuffer))
838
+ clGetProgramBuildInfo(this .handle, deviceId.handle , CL_PROGRAM_BUILD_LOG , null : PointerBuffer , sizeBuffer))
831
839
val logBuffer = MemoryUtil .memAlloc(sizeBuffer.get(0 ).toInt) // stack.malloc()
832
840
try {
833
- checkErrorCode(clGetProgramBuildInfo(this .handle, deviceId, CL_PROGRAM_BUILD_LOG , logBuffer, null ))
841
+ checkErrorCode(clGetProgramBuildInfo(this .handle, deviceId.handle , CL_PROGRAM_BUILD_LOG , logBuffer, null ))
834
842
(deviceId, decodeString(logBuffer))
835
843
} finally {
836
844
MemoryUtil .memFree(logBuffer)
@@ -841,7 +849,7 @@ object OpenCL {
841
849
}
842
850
}
843
851
844
- private def checkBuildErrorCode (deviceIdsOption : Option [Seq [Long ]], errorCode : Int ): Unit = {
852
+ private def checkBuildErrorCode (deviceIdsOption : Option [Seq [DeviceId [ Owner ] ]], errorCode : Int ): Unit = {
845
853
errorCode match {
846
854
case CL_BUILD_PROGRAM_FAILURE =>
847
855
val logs = deviceIdsOption match {
@@ -853,10 +861,12 @@ object OpenCL {
853
861
}
854
862
}
855
863
856
- def build (deviceIds : Seq [Long ], options : CharSequence = " " ): Unit = {
864
+ def build (deviceIds : Seq [DeviceId [ Owner ] ], options : CharSequence = " " ): Unit = {
857
865
val stack = stackPush()
858
866
try {
859
- checkBuildErrorCode(Some (deviceIds), clBuildProgram(handle, stack.pointers(deviceIds : _* ), options, null , NULL ))
867
+ checkBuildErrorCode(
868
+ Some (deviceIds),
869
+ clBuildProgram(handle, stack.pointers(deviceIds.view.map(_.handle): _* ), options, null , NULL ))
860
870
} finally {
861
871
stack.close()
862
872
}
@@ -920,7 +930,7 @@ object OpenCL {
920
930
commandQueue : CommandQueue ,
921
931
deviceBuffer : DeviceBuffer [Element ],
922
932
hostBuffer : Destination ,
923
- preconditionEvents : Event * )(implicit memory : Aux [Element , Destination ]): Do [Event ] =
933
+ preconditionEvents : Event * )(implicit memory : Memory . Aux [Element , Destination ]): Do [Event ] =
924
934
super .enqueueReadBuffer(commandQueue, deviceBuffer, hostBuffer, preconditionEvents : _* ).map { event =>
925
935
@ tailrec
926
936
def enqueueEvent (): Unit = {
@@ -1010,8 +1020,13 @@ object OpenCL {
1010
1020
1011
1021
trait LogContextNotification extends OpenCL {
1012
1022
1013
- protected def handleOpenCLNotification (errorInfo : String , privateInfo : ByteBuffer ): Unit = {
1014
- Logger .takingImplicit[ByteBuffer ](logger.underlying).info(errorInfo)(privateInfo)
1023
+ protected def handleOpenCLNotification (errorInfo : String , privateInfoOption : Option [ByteBuffer ]): Unit = {
1024
+ privateInfoOption match {
1025
+ case None =>
1026
+ logger.info(errorInfo)
1027
+ case Some (privateInfo) =>
1028
+ Logger .takingImplicit[ByteBuffer ](logger.underlying).info(errorInfo)(privateInfo)
1029
+ }
1015
1030
}
1016
1031
}
1017
1032
@@ -1184,7 +1199,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
1184
1199
releaseContext >> super .monadicClose
1185
1200
}
1186
1201
1187
- protected def handleOpenCLNotification (errorInfo : String , privateInfo : ByteBuffer ): Unit
1202
+ protected def handleOpenCLNotification (errorInfo : String , privateInfo : Option [ ByteBuffer ] ): Unit
1188
1203
1189
1204
import OpenCL ._
1190
1205
0 commit comments