diff --git a/opal/mca/common/cuda/common_cuda.c b/opal/mca/common/cuda/common_cuda.c
index 5ab9c125f30..1ce339062fc 100644
--- a/opal/mca/common/cuda/common_cuda.c
+++ b/opal/mca/common/cuda/common_cuda.c
@@ -1739,19 +1739,19 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
     int res;
     CUmemorytype memType = 0;
     CUdeviceptr dbuf = (CUdeviceptr)pUserBuf;
-    CUcontext ctx = NULL;
+    CUcontext ctx = NULL, memCtx = NULL;
 #if OPAL_CUDA_GET_ATTRIBUTES
     uint32_t isManaged = 0;
     /* With CUDA 7.0, we can get multiple attributes with a single call */
     CUpointer_attribute attributes[3] = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
                                          CU_POINTER_ATTRIBUTE_CONTEXT,
                                          CU_POINTER_ATTRIBUTE_IS_MANAGED};
-    void *attrdata[] = {(void *)&memType, (void *)&ctx, (void *)&isManaged};
+    void *attrdata[] = {(void *)&memType, (void *)&memCtx, (void *)&isManaged};
 
     res = cuFunc.cuPointerGetAttributes(3, attributes, attrdata, dbuf);
     OPAL_OUTPUT_VERBOSE((101, mca_common_cuda_output,
-                         "dbuf=%p, memType=%d, ctx=%p, isManaged=%d, res=%d",
-                         (void *)dbuf, (int)memType, (void *)ctx, isManaged, res));
+                         "dbuf=%p, memType=%d, memCtx=%p, isManaged=%d, res=%d",
+                         (void *)dbuf, (int)memType, (void *)memCtx, isManaged, res));
 
     /* Mark unified memory buffers with a flag. This will allow all unified
      * memory to be forced through host buffers. Note that this memory can
@@ -1787,6 +1787,7 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
     }
     /* Must be a device pointer */
     assert(memType == CU_MEMORYTYPE_DEVICE);
+#endif /* OPAL_CUDA_GET_ATTRIBUTES */
 
     /* This piece of code was added in to handle in a case involving
      * OMP threads. The user had initialized CUDA and then spawned
@@ -1797,25 +1798,25 @@
      * and set the current context to that. It is rare that we will not
      * have a context. */
     res = cuFunc.cuCtxGetCurrent(&ctx);
-#endif /* OPAL_CUDA_GET_ATTRIBUTES */
     if (OPAL_UNLIKELY(NULL == ctx)) {
         if (CUDA_SUCCESS == res) {
-            res = cuFunc.cuPointerGetAttribute(&ctx,
+#if !OPAL_CUDA_GET_ATTRIBUTES
+            res = cuFunc.cuPointerGetAttribute(&memCtx,
                                                CU_POINTER_ATTRIBUTE_CONTEXT, dbuf);
-            if (res != CUDA_SUCCESS) {
+            if (OPAL_UNLIKELY(res != CUDA_SUCCESS)) {
                 opal_output(0, "CUDA: error calling cuPointerGetAttribute: "
                             "res=%d, ptr=%p aborting...", res, pUserBuf);
                 return OPAL_ERROR;
+            }
+#endif /* OPAL_CUDA_GET_ATTRIBUTES */
+            res = cuFunc.cuCtxSetCurrent(memCtx);
+            if (OPAL_UNLIKELY(res != CUDA_SUCCESS)) {
+                opal_output(0, "CUDA: error calling cuCtxSetCurrent: "
+                            "res=%d, ptr=%p aborting...", res, pUserBuf);
+                return OPAL_ERROR;
             } else {
-                res = cuFunc.cuCtxSetCurrent(ctx);
-                if (res != CUDA_SUCCESS) {
-                    opal_output(0, "CUDA: error calling cuCtxSetCurrent: "
-                                "res=%d, ptr=%p aborting...", res, pUserBuf);
-                    return OPAL_ERROR;
-                } else {
-                    opal_output_verbose(10, mca_common_cuda_output,
-                                        "CUDA: cuCtxSetCurrent passed: ptr=%p", pUserBuf);
-                }
+                OPAL_OUTPUT_VERBOSE((10, mca_common_cuda_output,
+                                     "CUDA: cuCtxSetCurrent passed: ptr=%p", pUserBuf));
             }
         } else {
             /* Print error and proceed */
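
For reference, here is a minimal standalone sketch of the pattern the patched code implements: query the memory type, owning context, and managed flag with a single cuPointerGetAttributes() call (available since CUDA 7.0), and bind the calling thread to the buffer's owning context when cuCtxGetCurrent() reports none. The helper name ensure_ctx_for_ptr is invented for illustration; it calls the CUDA driver API directly rather than going through Open MPI's cuFunc dispatch table, and its error handling is reduced to stderr.

#include <cuda.h>
#include <stdio.h>

/* Hypothetical helper, not part of Open MPI: returns 1 for a device
 * pointer with a usable current context, 0 for host/unknown memory,
 * -1 on error. */
static int ensure_ctx_for_ptr(const void *buf)
{
    CUdeviceptr dbuf = (CUdeviceptr)buf;
    CUmemorytype memType = 0;
    CUcontext ctx = NULL, memCtx = NULL;
    unsigned int isManaged = 0;

    CUpointer_attribute attributes[3] = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
                                         CU_POINTER_ATTRIBUTE_CONTEXT,
                                         CU_POINTER_ATTRIBUTE_IS_MANAGED};
    void *attrdata[] = {(void *)&memType, (void *)&memCtx, (void *)&isManaged};

    /* Unlike cuPointerGetAttribute(), cuPointerGetAttributes() does not
     * fail on non-CUDA pointers; unresolvable attributes keep their
     * default values, so memType stays 0 for plain host memory. */
    CUresult res = cuPointerGetAttributes(3, attributes, attrdata, dbuf);
    if (res != CUDA_SUCCESS || memType != CU_MEMORYTYPE_DEVICE) {
        return 0;
    }

    /* A thread created after CUDA was initialized elsewhere (the OMP
     * case described in the comment in the diff) may have no current
     * context; fall back to the context that owns the buffer. */
    res = cuCtxGetCurrent(&ctx);
    if (CUDA_SUCCESS == res && NULL == ctx) {
        res = cuCtxSetCurrent(memCtx);
        if (res != CUDA_SUCCESS) {
            fprintf(stderr, "cuCtxSetCurrent failed: res=%d, ptr=%p\n",
                    (int)res, buf);
            return -1;
        }
    }
    return 1;
}

The #if/#endif shuffle in the diff preserves the same fallback for builds without OPAL_CUDA_GET_ATTRIBUTES, where the owning context is not delivered by the batched query and must instead be fetched with a separate cuPointerGetAttribute() call before cuCtxSetCurrent().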