11#include < OffloadAPI.h>
22#include < ur/ur.hpp>
33#include < ur_api.h>
4- #include < cuda.h>
54
65#include " context.hpp"
76#include " program.hpp"
87#include " ur2offload.hpp"
98
9+ #ifdef UR_CUDA_ENABLED
10+ #include < cuda.h>
11+ #endif
12+
13+ namespace {
14+ // Workaround for Offload not supporting PTX binaries. Force CUDA programs
15+ // to be linked so they end up as CUBIN.
16+ #ifdef UR_CUDA_ENABLED
17+ ur_result_t ProgramCreateCudaWorkaround (ur_context_handle_t hContext,
18+ const uint8_t *Binary, size_t Length,
19+ ur_program_handle_t *phProgram) {
20+ uint8_t *RealBinary;
21+ size_t RealLength;
22+ CUlinkState State;
23+ cuLinkCreate (0 , nullptr , nullptr , &State);
24+
25+ cuLinkAddData (State, CU_JIT_INPUT_PTX, (char *)(Binary), Length, nullptr , 0 ,
26+ nullptr , nullptr );
27+
28+ void *CuBin = nullptr ;
29+ size_t CuBinSize = 0 ;
30+ cuLinkComplete (State, &CuBin, &CuBinSize);
31+ RealBinary = (uint8_t *)CuBin;
32+ RealLength = CuBinSize;
33+ fprintf (stderr, " Performed CUDA bin workaround (size = %lu)\n " , RealLength);
34+
35+ ur_program_handle_t Program = new ur_program_handle_t_ ();
36+ auto Res =
37+ olCreateProgram (reinterpret_cast <ol_device_handle_t >(hContext->Device ),
38+ RealBinary, RealLength, &Program->OffloadProgram );
39+
40+ // Program owns the linked module now
41+ cuLinkDestroy (State);
42+ (void )State;
43+
44+ if (Res != OL_SUCCESS) {
45+ delete Program;
46+ return offloadResultToUR (Res);
47+ }
48+
49+ *phProgram = Program;
50+
51+ return UR_RESULT_SUCCESS;
52+ }
53+ #else
54+ ur_result_t ProgramCreateCudaWorkaround (ur_context_handle_t , const uint8_t *,
55+ size_t , ur_program_handle_t *) {
56+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
57+ }
58+ #endif
59+ } // namespace
60+
1061UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary (
1162 ur_context_handle_t hContext, uint32_t numDevices,
1263 ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
@@ -15,45 +66,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
1566 return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1667 }
1768
18- // Workaround for Offload not supporting PTX binaries. Force CUDA programs
19- // to be linked so they end up as CUBIN.
20- uint8_t *RealBinary;
21- size_t RealLength;
2269 ur_platform_handle_t DevicePlatform;
23- bool DidLink = false ;
24- CUlinkState State;
2570 urDeviceGetInfo (phDevices[0 ], UR_DEVICE_INFO_PLATFORM,
2671 sizeof (ur_platform_handle_t ), &DevicePlatform, nullptr );
2772 ur_platform_backend_t PlatformBackend;
2873 urPlatformGetInfo (DevicePlatform, UR_PLATFORM_INFO_BACKEND,
2974 sizeof (ur_platform_backend_t ), &PlatformBackend, nullptr );
3075 if (PlatformBackend == UR_PLATFORM_BACKEND_CUDA) {
31- cuLinkCreate (0 , nullptr , nullptr , &State);
32-
33- cuLinkAddData (State, CU_JIT_INPUT_PTX, (char *)(ppBinaries[0 ]), pLengths[0 ],
34- nullptr , 0 , nullptr , nullptr );
35-
36- void *CuBin = nullptr ;
37- size_t CuBinSize = 0 ;
38- cuLinkComplete (State, &CuBin, &CuBinSize);
39- RealBinary = (uint8_t *)CuBin;
40- RealLength = CuBinSize;
41- DidLink = true ;
42- fprintf (stderr, " Performed CUDA bin workaround (size = %lu)\n " , RealLength);
43- } else {
44- RealBinary = const_cast <uint8_t *>(ppBinaries[0 ]);
45- RealLength = pLengths[0 ];
76+ return ProgramCreateCudaWorkaround (hContext, ppBinaries[0 ], pLengths[0 ],
77+ phProgram);
4678 }
4779
80+ auto *RealBinary = const_cast <uint8_t *>(ppBinaries[0 ]);
81+
4882 ur_program_handle_t Program = new ur_program_handle_t_ ();
4983 auto Res =
5084 olCreateProgram (reinterpret_cast <ol_device_handle_t >(hContext->Device ),
51- RealBinary, RealLength, &Program->OffloadProgram );
52-
53- // Program owns the linked module now
54- if (DidLink) {
55- cuLinkDestroy (State);
56- }
85+ RealBinary, pLengths[0 ], &Program->OffloadProgram );
5786
5887 if (Res != OL_SUCCESS) {
5988 delete Program;
@@ -80,7 +109,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
80109 return UR_RESULT_SUCCESS;
81110}
82111
83-
84112UR_APIEXPORT ur_result_t UR_APICALL
85113urProgramRetain (ur_program_handle_t hProgram) {
86114 hProgram->RefCount ++;
0 commit comments