@@ -1739,19 +1739,19 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
17391739    int  res ;
17401740    CUmemorytype  memType  =  0 ;
17411741    CUdeviceptr  dbuf  =  (CUdeviceptr )pUserBuf ;
1742-     CUcontext  ctx  =  NULL ;
1742+     CUcontext  ctx  =  NULL ,  memCtx   =   NULL ;
17431743#if  OPAL_CUDA_GET_ATTRIBUTES 
17441744    uint32_t  isManaged  =  0 ;
17451745    /* With CUDA 7.0, we can get multiple attributes with a single call */ 
17461746    CUpointer_attribute  attributes [3 ] =  {CU_POINTER_ATTRIBUTE_MEMORY_TYPE ,
17471747                                         CU_POINTER_ATTRIBUTE_CONTEXT ,
17481748                                         CU_POINTER_ATTRIBUTE_IS_MANAGED };
1749-     void  * attrdata [] =  {(void  * )& memType , (void  * )& ctx , (void  * )& isManaged };
1749+     void  * attrdata [] =  {(void  * )& memType , (void  * )& memCtx , (void  * )& isManaged };
17501750
17511751    res  =  cuFunc .cuPointerGetAttributes (3 , attributes , attrdata , dbuf );
17521752    OPAL_OUTPUT_VERBOSE ((101 , mca_common_cuda_output ,
17531753                        "dbuf=%p, memType=%d, ctx=%p, isManaged=%d, res=%d" ,
1754-                          (void  * )dbuf , (int )memType , (void  * )ctx , isManaged , res ));
1754+                          (void  * )dbuf , (int )memType , (void  * )memCtx , isManaged , res ));
17551755
17561756    /* Mark unified memory buffers with a flag.  This will allow all unified 
17571757     * memory to be forced through host buffers.  Note that this memory can 
@@ -1787,6 +1787,7 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
17871787    }
17881788    /* Must be a device pointer */ 
17891789    assert (memType  ==  CU_MEMORYTYPE_DEVICE );
1790+ #endif  /* OPAL_CUDA_GET_ATTRIBUTES */ 
17901791
17911792    /* This piece of code was added in to handle in a case involving 
17921793     * OMP threads.  The user had initialized CUDA and then spawned 
@@ -1797,25 +1798,25 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
17971798     * and set the current context to that.  It is rare that we will not 
17981799     * have a context. */ 
17991800    res  =  cuFunc .cuCtxGetCurrent (& ctx );
1800- #endif  /* OPAL_CUDA_GET_ATTRIBUTES */ 
18011801    if  (OPAL_UNLIKELY (NULL  ==  ctx )) {
18021802        if  (CUDA_SUCCESS  ==  res ) {
1803-             res  =  cuFunc .cuPointerGetAttribute (& ctx ,
1803+ #if  !OPAL_CUDA_GET_ATTRIBUTES 
1804+             res  =  cuFunc .cuPointerGetAttribute (& memCtx ,
18041805                                               CU_POINTER_ATTRIBUTE_CONTEXT , dbuf );
18051806            if  (res  !=  CUDA_SUCCESS ) {
18061807                opal_output (0 , "CUDA: error calling cuPointerGetAttribute: " 
18071808                            "res=%d, ptr=%p aborting..." , res , pUserBuf );
18081809                return  OPAL_ERROR ;
1810+             }
1811+ #endif  /* OPAL_CUDA_GET_ATTRIBUTES */ 
1812+             res  =  cuFunc .cuCtxSetCurrent (memCtx );
1813+             if  (res  !=  CUDA_SUCCESS ) {
1814+                 opal_output (0 , "CUDA: error calling cuCtxSetCurrent: " 
1815+                             "res=%d, ptr=%p aborting..." , res , pUserBuf );
1816+                 return  OPAL_ERROR ;
18091817            } else  {
1810-                 res  =  cuFunc .cuCtxSetCurrent (ctx );
1811-                 if  (res  !=  CUDA_SUCCESS ) {
1812-                     opal_output (0 , "CUDA: error calling cuCtxSetCurrent: " 
1813-                                 "res=%d, ptr=%p aborting..." , res , pUserBuf );
1814-                     return  OPAL_ERROR ;
1815-                 } else  {
1816-                     opal_output_verbose (10 , mca_common_cuda_output ,
1817-                                         "CUDA: cuCtxSetCurrent passed: ptr=%p" , pUserBuf );
1818-                 }
1818+                 opal_output_verbose (10 , mca_common_cuda_output ,
1819+                                     "CUDA: cuCtxSetCurrent passed: ptr=%p" , pUserBuf );
18191820            }
18201821        } else  {
18211822            /* Print error and proceed */ 
0 commit comments