@@ -5,20 +5,17 @@ package xyz.bluspring.unitytranslate.translator
55// #endif
66import dev.architectury.event.events.common.LifecycleEvent
77import dev.architectury.event.events.common.PlayerEvent
8- import net.minecraft.Util
98import net.minecraft.network.FriendlyByteBuf
109import net.minecraft.server.level.ServerPlayer
1110import net.minecraft.world.entity.player.Player
12- import org.lwjgl.system.APIUtil
13- import org.lwjgl.system.JNI
14- import org.lwjgl.system.MemoryUtil
15- import org.lwjgl.system.SharedLibrary
1611import xyz.bluspring.unitytranslate.Language
1712import xyz.bluspring.unitytranslate.UnityTranslate
1813import xyz.bluspring.unitytranslate.client.UnityTranslateClient
1914import xyz.bluspring.unitytranslate.compat.voicechat.UTVoiceChatCompat
2015import xyz.bluspring.unitytranslate.network.PacketIds
2116import xyz.bluspring.unitytranslate.util.ClassLoaderProviderForkJoinWorkerThreadFactory
17+ import xyz.bluspring.unitytranslate.util.nativeaccess.CudaState
18+ import xyz.bluspring.unitytranslate.util.nativeaccess.NativeAccess
2219import java.util.*
2320import java.util.concurrent.CompletableFuture
2421import java.util.concurrent.ConcurrentLinkedDeque
@@ -141,132 +138,26 @@ object TranslatorManager {
141138 return null
142139 }
143140
144- private var isLibraryLoaded = false
145- private lateinit var library: SharedLibrary
146- private var PFN_cuInit : Long = 0L
147- private var PFN_cuDeviceGetCount : Long = 0L
148- private var PFN_cuDeviceComputeCapability : Long = 0L
141+ var hasChecked = false
149142
150- private var PFN_cuGetErrorName : Long = 0L
151- private var PFN_cuGetErrorString : Long = 0L
152-
153- private fun logCudaError (code : Int , at : String ) {
154- if (code == 0 )
155- return
156-
157- // TODO: these return ??? for some reason.
158- // can we figure out why?
159-
160- val errorCode = if (PFN_cuGetErrorName != MemoryUtil .NULL ) {
161- val ptr = MemoryUtil .nmemAlloc(255 )
162- JNI .callPP(code, ptr, PFN_cuGetErrorName )
163- MemoryUtil .memUTF16(ptr).apply {
164- MemoryUtil .nmemFree(ptr)
165- }
166- } else " [CUDA ERROR NAME NOT FOUND]"
167-
168- val errorDesc = if (PFN_cuGetErrorString != MemoryUtil .NULL ) {
169- val ptr = MemoryUtil .nmemAlloc(255 )
170- JNI .callPP(code, ptr, PFN_cuGetErrorString )
171- MemoryUtil .memUTF16(ptr).apply {
172- MemoryUtil .nmemFree(ptr)
173- }
174- } else " [CUDA ERROR DESC NOT FOUND]"
175-
176- UnityTranslate .logger.error(" CUDA error at $at : $code $errorCode ($errorDesc )" )
177- }
178-
179- private fun isCudaSupported (): Boolean {
143+ fun checkSupportsCuda (): Boolean {
180144 if (! UnityTranslate .config.server.shouldUseCuda) {
181145 UnityTranslate .logger.info(" CUDA is disabled in the config, not enabling CUDA support." )
182146 return false
183147 }
184148
185- if (! isLibraryLoaded) {
186- try {
187- library = if (Util .getPlatform() == Util .OS .WINDOWS ) {
188- APIUtil .apiCreateLibrary(" nvcuda.dll" )
189- } else if (Util .getPlatform() == Util .OS .LINUX ) {
190- APIUtil .apiCreateLibrary(" libcuda.so" )
191- } else {
192- return false
149+ return NativeAccess .isCudaSupported().apply {
150+ if (! hasChecked) {
151+ if (this == CudaState .AVAILABLE )
152+ UnityTranslate .logger.info(" CUDA is supported, using GPU for translations." )
153+ else {
154+ UnityTranslate .logger.info(" CUDA is not supported, using CPU for translations." )
155+ UnityTranslate .logger.info(" CUDA state: $ordinal ($name ): $message " )
193156 }
194157
195- PFN_cuInit = library.getFunctionAddress(" cuInit" )
196- PFN_cuDeviceGetCount = library.getFunctionAddress(" cuDeviceGetCount" )
197- PFN_cuDeviceComputeCapability = library.getFunctionAddress(" cuDeviceComputeCapability" )
198- PFN_cuGetErrorName = library.getFunctionAddress(" cuGetErrorName" )
199- PFN_cuGetErrorString = library.getFunctionAddress(" cuGetErrorString" )
200-
201- if (PFN_cuInit == MemoryUtil .NULL || PFN_cuDeviceGetCount == MemoryUtil .NULL || PFN_cuDeviceComputeCapability == MemoryUtil .NULL ) {
202- // TODO: remove in prod
203- UnityTranslate .logger.info(" CUDA results: $PFN_cuInit $PFN_cuDeviceGetCount $PFN_cuDeviceComputeCapability " )
204- return false
205- }
206- } catch (_: UnsatisfiedLinkError ) {
207- UnityTranslate .logger.warn(" CUDA library failed to load! Not attempting to initialize CUDA functions." )
208- return false
209- } catch (e: Throwable ) {
210- UnityTranslate .logger.warn(" An error occurred while searching for CUDA devices! You don't have to report this, don't worry." )
211- e.printStackTrace()
212- return false
158+ hasChecked = true
213159 }
214-
215- isLibraryLoaded = true
216- }
217-
218- val success = 0
219-
220- if (JNI .callI(0 , PFN_cuInit ).apply {
221- logCudaError(this , " init" )
222- } != success) {
223- return false
224- }
225-
226- val totalPtr = MemoryUtil .nmemAlloc(Int .SIZE_BYTES .toLong())
227- if (JNI .callPI(totalPtr, PFN_cuDeviceGetCount ).apply {
228- logCudaError(this , " get device count" )
229- } != success) {
230- return false
231- }
232-
233- val totalCudaDevices = MemoryUtil .memGetInt(totalPtr)
234- UnityTranslate .logger.info(" Total CUDA devices: $totalCudaDevices " )
235- if (totalCudaDevices <= 0 ) {
236- return false
237- }
238-
239- MemoryUtil .nmemFree(totalPtr)
240-
241- for (i in 0 until totalCudaDevices) {
242- val minorPtr = MemoryUtil .nmemAlloc(Int .SIZE_BYTES .toLong())
243- val majorPtr = MemoryUtil .nmemAlloc(Int .SIZE_BYTES .toLong())
244-
245- if (JNI .callPPI(majorPtr, minorPtr, i, PFN_cuDeviceComputeCapability ).apply {
246- logCudaError(this , " get device compute capability $i " )
247- } != success) {
248- continue
249- }
250-
251- val majorVersion = MemoryUtil .memGetInt(majorPtr)
252- val minorVersion = MemoryUtil .memGetInt(minorPtr)
253-
254- MemoryUtil .nmemFree(majorPtr)
255- MemoryUtil .nmemFree(minorPtr)
256-
257- UnityTranslate .logger.info(" Found device with CUDA compute capability major $majorVersion minor $minorVersion ." )
258-
259- return true
260- }
261-
262- return false
263- }
264-
265- val supportsCuda = isCudaSupported().apply {
266- if (this )
267- UnityTranslate .logger.info(" CUDA is supported, using GPU for translations." )
268- else
269- UnityTranslate .logger.info(" CUDA is not supported, using CPU for translations." )
160+ } == CudaState .AVAILABLE
270161 }
271162
272163 fun installLibreTranslate () {
0 commit comments