@@ -28,21 +28,34 @@ static bool emit_cuda_c_code(CudaKernel* spec) {
28
28
Module * final_mod ;
29
29
emit_c (config , emitter_config , dst_mod , & spec -> cuda_code_size , & spec -> cuda_code , & final_mod );
30
30
spec -> final_module = final_mod ;
31
+
32
+ if (get_log_level () <= DEBUG )
33
+ write_file ("cuda_dump.cu" , spec -> cuda_code_size - 1 , spec -> cuda_code );
34
+
31
35
return true;
32
36
}
33
37
34
38
static bool cuda_c_to_ptx (CudaKernel * kernel ) {
35
39
nvrtcProgram program ;
36
40
CHECK_NVRTC (nvrtcCreateProgram (& program , kernel -> cuda_code , kernel -> key .entry_point , 0 , NULL , NULL ), return false );
37
- nvrtcResult compile_result = nvrtcCompileProgram (program , 0 , false);
41
+
42
+ assert (kernel -> device -> cc_major < 10 && kernel -> device -> cc_minor < 10 );
43
+
44
+ char arch_flag [] = "-arch=compute_00" ;
45
+ arch_flag [14 ] = '0' + kernel -> device -> cc_major ;
46
+ arch_flag [15 ] = '0' + kernel -> device -> cc_minor ;
47
+
48
+ const char * options [] = {
49
+ arch_flag ,
50
+ "--use_fast_math"
51
+ };
52
+
53
+ nvrtcResult compile_result = nvrtcCompileProgram (program , sizeof (options )/sizeof (* options ), options );
38
54
if (compile_result != NVRTC_SUCCESS ) {
39
55
error_print ("NVRTC compilation failed: %s\n" , nvrtcGetErrorString (compile_result ));
40
56
debug_print ("Dumping source:\n%s" , kernel -> cuda_code );
41
57
}
42
58
43
- if (get_log_level () <= DEBUG )
44
- write_file ("cuda_dump.cu" , kernel -> cuda_code_size - 1 , kernel -> cuda_code );
45
-
46
59
size_t log_size ;
47
60
CHECK_NVRTC (nvrtcGetProgramLogSize (program , & log_size ), return false );
48
61
char * log_buffer = calloc (log_size , 1 );
@@ -60,13 +73,58 @@ static bool cuda_c_to_ptx(CudaKernel* kernel) {
60
73
read_file (override_file , & kernel -> ptx_size , & kernel -> ptx );
61
74
}
62
75
76
+ if (get_log_level () <= DEBUG )
77
+ write_file ("cuda_dump.ptx" , kernel -> ptx_size - 1 , kernel -> ptx );
78
+
63
79
return true;
64
80
}
65
81
66
82
static bool load_ptx_into_cuda_program (CudaKernel * kernel ) {
67
- CHECK_CUDA (cuModuleLoadDataEx (& kernel -> cuda_module , kernel -> ptx , 0 , NULL , NULL ), return false );
68
- CHECK_CUDA (cuModuleGetFunction (& kernel -> entry_point_function , kernel -> cuda_module , kernel -> key .entry_point ), return false );
83
+ char info_log [10240 ] = {};
84
+ char error_log [10240 ] = {};
85
+
86
+ CUjit_option options [] = {
87
+ CU_JIT_INFO_LOG_BUFFER , CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES ,
88
+ CU_JIT_ERROR_LOG_BUFFER , CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES ,
89
+ CU_JIT_TARGET
90
+ };
91
+
92
+ void * option_values [] = {
93
+ info_log , (void * )(uintptr_t )sizeof (info_log ),
94
+ error_log , (void * )(uintptr_t )sizeof (error_log ),
95
+ (void * )(uintptr_t )(kernel -> device -> cc_major * 10 + kernel -> device -> cc_minor )
96
+ };
97
+
98
+ CUlinkState linker ;
99
+ CHECK_CUDA (cuLinkCreate (sizeof (options )/sizeof (options [0 ]), options , option_values , & linker ), goto err_linker_create );
100
+ CHECK_CUDA (cuLinkAddData (linker , CU_JIT_INPUT_PTX , kernel -> ptx , kernel -> ptx_size , NULL , 0U , NULL , NULL ), goto err_post_linker_create );
101
+
102
+ void * binary ;
103
+ size_t binary_size ;
104
+ CHECK_CUDA (cuLinkComplete (linker , & binary , & binary_size ), goto err_post_linker_create );
105
+
106
+ if (* info_log )
107
+ info_print ("CUDA JIT info: %s\n" , info_log );
108
+
109
+ if (get_log_level () <= DEBUG )
110
+ write_file ("cuda_dump.cubin" , binary_size , binary );
111
+
112
+ CHECK_CUDA (cuModuleLoadData (& kernel -> cuda_module , binary ), goto err_post_linker_create );
113
+ CHECK_CUDA (cuModuleGetFunction (& kernel -> entry_point_function , kernel -> cuda_module , kernel -> key .entry_point ), goto err_post_module_load );
114
+
115
+ cuLinkDestroy (linker );
69
116
return true;
117
+
118
+ err_post_module_load :
119
+ cuModuleUnload (kernel -> cuda_module );
120
+ err_post_linker_create :
121
+ cuLinkDestroy (linker );
122
+ if (* info_log )
123
+ info_print ("CUDA JIT info: %s\n" , info_log );
124
+ if (* error_log )
125
+ error_print ("CUDA JIT failed: %s\n" , error_log );
126
+ err_linker_create :
127
+ return false;
70
128
}
71
129
72
130
static CudaKernel * create_specialized_program (CudaDevice * device , SpecProgramKey key ) {
0 commit comments