@@ -1739,19 +1739,19 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
17391739 int res ;
17401740 CUmemorytype memType = 0 ;
17411741 CUdeviceptr dbuf = (CUdeviceptr )pUserBuf ;
1742- CUcontext ctx = NULL ;
1742+ CUcontext ctx = NULL , memCtx = NULL ;
17431743#if OPAL_CUDA_GET_ATTRIBUTES
17441744 uint32_t isManaged = 0 ;
17451745 /* With CUDA 7.0, we can get multiple attributes with a single call */
17461746 CUpointer_attribute attributes [3 ] = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE ,
17471747 CU_POINTER_ATTRIBUTE_CONTEXT ,
17481748 CU_POINTER_ATTRIBUTE_IS_MANAGED };
1749- void * attrdata [] = {(void * )& memType , (void * )& ctx , (void * )& isManaged };
1749+ void * attrdata [] = {(void * )& memType , (void * )& memCtx , (void * )& isManaged };
17501750
17511751 res = cuFunc .cuPointerGetAttributes (3 , attributes , attrdata , dbuf );
17521752 OPAL_OUTPUT_VERBOSE ((101 , mca_common_cuda_output ,
1753- "dbuf=%p, memType=%d, ctx =%p, isManaged=%d, res=%d" ,
1754- (void * )dbuf , (int )memType , (void * )ctx , isManaged , res ));
1753+ "dbuf=%p, memType=%d, memCtx =%p, isManaged=%d, res=%d" ,
1754+ (void * )dbuf , (int )memType , (void * )memCtx , isManaged , res ));
17551755
17561756 /* Mark unified memory buffers with a flag. This will allow all unified
17571757 * memory to be forced through host buffers. Note that this memory can
@@ -1787,6 +1787,7 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
17871787 }
17881788 /* Must be a device pointer */
17891789 assert (memType == CU_MEMORYTYPE_DEVICE );
1790+ #endif /* OPAL_CUDA_GET_ATTRIBUTES */
17901791
17911792 /* This piece of code was added in to handle in a case involving
17921793 * OMP threads. The user had initialized CUDA and then spawned
@@ -1797,25 +1798,25 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
17971798 * and set the current context to that. It is rare that we will not
17981799 * have a context. */
17991800 res = cuFunc .cuCtxGetCurrent (& ctx );
1800- #endif /* OPAL_CUDA_GET_ATTRIBUTES */
18011801 if (OPAL_UNLIKELY (NULL == ctx )) {
18021802 if (CUDA_SUCCESS == res ) {
1803- res = cuFunc .cuPointerGetAttribute (& ctx ,
1803+ #if !OPAL_CUDA_GET_ATTRIBUTES
1804+ res = cuFunc .cuPointerGetAttribute (& memCtx ,
18041805 CU_POINTER_ATTRIBUTE_CONTEXT , dbuf );
1805- if (res != CUDA_SUCCESS ) {
1806+ if (OPAL_UNLIKELY ( res != CUDA_SUCCESS ) ) {
18061807 opal_output (0 , "CUDA: error calling cuPointerGetAttribute: "
18071808 "res=%d, ptr=%p aborting..." , res , pUserBuf );
18081809 return OPAL_ERROR ;
1810+ }
1811+ #endif /* OPAL_CUDA_GET_ATTRIBUTES */
1812+ res = cuFunc .cuCtxSetCurrent (memCtx );
1813+ if (OPAL_UNLIKELY (res != CUDA_SUCCESS )) {
1814+ opal_output (0 , "CUDA: error calling cuCtxSetCurrent: "
1815+ "res=%d, ptr=%p aborting..." , res , pUserBuf );
1816+ return OPAL_ERROR ;
18091817 } else {
1810- res = cuFunc .cuCtxSetCurrent (ctx );
1811- if (res != CUDA_SUCCESS ) {
1812- opal_output (0 , "CUDA: error calling cuCtxSetCurrent: "
1813- "res=%d, ptr=%p aborting..." , res , pUserBuf );
1814- return OPAL_ERROR ;
1815- } else {
1816- opal_output_verbose (10 , mca_common_cuda_output ,
1817- "CUDA: cuCtxSetCurrent passed: ptr=%p" , pUserBuf );
1818- }
1818+ OPAL_OUTPUT_VERBOSE ((10 , mca_common_cuda_output ,
1819+ "CUDA: cuCtxSetCurrent passed: ptr=%p" , pUserBuf ));
18191820 }
18201821 } else {
18211822 /* Print error and proceed */
0 commit comments