Skip to content

Commit d3fa009

Browse files
improved CUDA module loading
1 parent 13c0351 commit d3fa009

File tree

3 files changed

+68
-6
lines changed

3 files changed

+68
-6
lines changed

src/runtime/cuda/cuda_runtime.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ static CudaDevice* create_cuda_device(CudaBackend* b, int ordinal) {
6969
.specialized_programs = new_dict(SpecProgramKey, CudaKernel*, (HashFn) hash_spec_program_key, (CmpFn) cmp_spec_program_keys),
7070
};
7171
CHECK_CUDA(cuDeviceGetName(device->name, 255, handle), goto dealloc_and_return_null);
72+
CHECK_CUDA(cuDeviceGetAttribute(&device->cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device->handle), goto dealloc_and_return_null);
73+
CHECK_CUDA(cuDeviceGetAttribute(&device->cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device->handle), goto dealloc_and_return_null);
7274
CHECK_CUDA(cuCtxCreate(&device->context, 0, handle), goto dealloc_and_return_null);
7375
return device;
7476

src/runtime/cuda/cuda_runtime_private.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ typedef struct {
2424
CUdevice handle;
2525
CUcontext context;
2626
char name[256];
27+
int cc_major;
28+
int cc_minor;
2729
struct Dict* specialized_programs;
2830
} CudaDevice;
2931

src/runtime/cuda/cuda_runtime_program.c

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,34 @@ static bool emit_cuda_c_code(CudaKernel* spec) {
2828
Module* final_mod;
2929
emit_c(config, emitter_config, dst_mod, &spec->cuda_code_size, &spec->cuda_code, &final_mod);
3030
spec->final_module = final_mod;
31+
32+
if (get_log_level() <= DEBUG)
33+
write_file("cuda_dump.cu", spec->cuda_code_size - 1, spec->cuda_code);
34+
3135
return true;
3236
}
3337

3438
static bool cuda_c_to_ptx(CudaKernel* kernel) {
3539
nvrtcProgram program;
3640
CHECK_NVRTC(nvrtcCreateProgram(&program, kernel->cuda_code, kernel->key.entry_point, 0, NULL, NULL), return false);
37-
nvrtcResult compile_result = nvrtcCompileProgram(program, 0, false);
41+
42+
assert(kernel->device->cc_major < 10 && kernel->device->cc_minor < 10);
43+
44+
char arch_flag[] = "-arch=compute_00";
45+
arch_flag[14] = '0' + kernel->device->cc_major;
46+
arch_flag[15] = '0' + kernel->device->cc_minor;
47+
48+
const char* options[] = {
49+
arch_flag,
50+
"--use_fast_math"
51+
};
52+
53+
nvrtcResult compile_result = nvrtcCompileProgram(program, sizeof(options)/sizeof(*options), options);
3854
if (compile_result != NVRTC_SUCCESS) {
3955
error_print("NVRTC compilation failed: %s\n", nvrtcGetErrorString(compile_result));
4056
debug_print("Dumping source:\n%s", kernel->cuda_code);
4157
}
4258

43-
if (get_log_level() <= DEBUG)
44-
write_file("cuda_dump.cu", kernel->cuda_code_size - 1, kernel->cuda_code);
45-
4659
size_t log_size;
4760
CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size), return false);
4861
char* log_buffer = calloc(log_size, 1);
@@ -60,13 +73,58 @@ static bool cuda_c_to_ptx(CudaKernel* kernel) {
6073
read_file(override_file, &kernel->ptx_size, &kernel->ptx);
6174
}
6275

76+
if (get_log_level() <= DEBUG)
77+
write_file("cuda_dump.ptx", kernel->ptx_size - 1, kernel->ptx);
78+
6379
return true;
6480
}
6581

6682
static bool load_ptx_into_cuda_program(CudaKernel* kernel) {
67-
CHECK_CUDA(cuModuleLoadDataEx(&kernel->cuda_module, kernel->ptx, 0, NULL, NULL), return false);
68-
CHECK_CUDA(cuModuleGetFunction(&kernel->entry_point_function, kernel->cuda_module, kernel->key.entry_point), return false);
83+
char info_log[10240] = {};
84+
char error_log[10240] = {};
85+
86+
CUjit_option options[] = {
87+
CU_JIT_INFO_LOG_BUFFER, CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
88+
CU_JIT_ERROR_LOG_BUFFER, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
89+
CU_JIT_TARGET
90+
};
91+
92+
void* option_values[] = {
93+
info_log, (void*)(uintptr_t)sizeof(info_log),
94+
error_log, (void*)(uintptr_t)sizeof(error_log),
95+
(void*)(uintptr_t)(kernel->device->cc_major * 10 + kernel->device->cc_minor)
96+
};
97+
98+
CUlinkState linker;
99+
CHECK_CUDA(cuLinkCreate(sizeof(options)/sizeof(options[0]), options, option_values, &linker), goto err_linker_create);
100+
CHECK_CUDA(cuLinkAddData(linker, CU_JIT_INPUT_PTX, kernel->ptx, kernel->ptx_size, NULL, 0U, NULL, NULL), goto err_post_linker_create);
101+
102+
void* binary;
103+
size_t binary_size;
104+
CHECK_CUDA(cuLinkComplete(linker, &binary, &binary_size), goto err_post_linker_create);
105+
106+
if (*info_log)
107+
info_print("CUDA JIT info: %s\n", info_log);
108+
109+
if (get_log_level() <= DEBUG)
110+
write_file("cuda_dump.cubin", binary_size, binary);
111+
112+
CHECK_CUDA(cuModuleLoadData(&kernel->cuda_module, binary), goto err_post_linker_create);
113+
CHECK_CUDA(cuModuleGetFunction(&kernel->entry_point_function, kernel->cuda_module, kernel->key.entry_point), goto err_post_module_load);
114+
115+
cuLinkDestroy(linker);
69116
return true;
117+
118+
err_post_module_load:
119+
cuModuleUnload(kernel->cuda_module);
120+
err_post_linker_create:
121+
cuLinkDestroy(linker);
122+
if (*info_log)
123+
info_print("CUDA JIT info: %s\n", info_log);
124+
if (*error_log)
125+
error_print("CUDA JIT failed: %s\n", error_log);
126+
err_linker_create:
127+
return false;
70128
}
71129

72130
static CudaKernel* create_specialized_program(CudaDevice* device, SpecProgramKey key) {

0 commit comments

Comments
 (0)