diff --git a/util/tracer_nvbit/tracer_tool/tracer_tool.cu b/util/tracer_nvbit/tracer_tool/tracer_tool.cu index b36c33b18..4dbef0e62 100644 --- a/util/tracer_nvbit/tracer_tool/tracer_tool.cu +++ b/util/tracer_nvbit/tracer_tool/tracer_tool.cu @@ -34,6 +34,19 @@ #define TRACER_VERSION "5" +static int get_attr_with_kernel_fallback(CUfunction func, + CUfunction_attribute attr) { + int value = 0; + CUresult res = cuFuncGetAttribute(&value, attr, func); + if (res == CUDA_ERROR_INVALID_HANDLE) { + CUdevice dev = 0; + if (cuCtxGetDevice(&dev) == CUDA_SUCCESS) { + cuKernelGetAttribute(&value, attr, (CUkernel)func, dev); + } + } + return value; +} + /* Channel used to communicate from GPU to CPU receiving thread */ #define CHANNEL_SIZE (1l << 20) static __managed__ ChannelDev channel_dev; @@ -502,16 +515,12 @@ static void enter_kernel_launch(CUcontext ctx, CUfunction func, } // Get the number of registers and shared memory size for the kernel - int nregs; - CUDA_SAFECALL(cuFuncGetAttribute(&nregs, CU_FUNC_ATTRIBUTE_NUM_REGS, func)); - - int shmem_static_nbytes; - CUDA_SAFECALL(cuFuncGetAttribute(&shmem_static_nbytes, - CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); - - int binary_version; - CUDA_SAFECALL(cuFuncGetAttribute(&binary_version, - CU_FUNC_ATTRIBUTE_BINARY_VERSION, func)); + int nregs = + get_attr_with_kernel_fallback(func, CU_FUNC_ATTRIBUTE_NUM_REGS); + int shmem_static_nbytes = get_attr_with_kernel_fallback( + func, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES); + int binary_version = + get_attr_with_kernel_fallback(func, CU_FUNC_ATTRIBUTE_BINARY_VERSION); // Instrument the kernel if needed instrument_function_if_needed(ctx, func);