@@ -108,6 +108,10 @@ struct cudaFunctionTable {
108108#if OPAL_CUDA_GET_ATTRIBUTES
109109 int (* cuPointerGetAttributes )(unsigned int , CUpointer_attribute * , void * * , CUdeviceptr );
110110#if OPAL_CUDA_VMM_SUPPORT
111+ int (* cuDevicePrimaryCtxRetain )(CUcontext * , CUdevice );
112+ int (* cuDevicePrimaryCtxGetState )(CUdevice , unsigned int * , int * );
113+ int (* cuMemPoolGetAccess )(CUmemAccess_flags * , CUmemoryPool , CUmemLocation * );
114+ int (* cuDeviceGetAttribute )(int * , CUdevice_attribute , CUdevice );
111115 int (* cuDeviceGetCount )(int * );
112116 int (* cuMemRelease )(CUmemGenericAllocationHandle );
113117 int (* cuMemRetainAllocationHandle )(CUmemGenericAllocationHandle * , void * );
@@ -488,6 +492,10 @@ int mca_common_cuda_stage_one_init(void)
488492 OPAL_CUDA_DLSYM (libcuda_handle , cuPointerGetAttributes );
489493#endif /* OPAL_CUDA_GET_ATTRIBUTES */
490494#if OPAL_CUDA_VMM_SUPPORT
495+ OPAL_CUDA_DLSYM (libcuda_handle , cuDevicePrimaryCtxRetain );
496+ OPAL_CUDA_DLSYM (libcuda_handle , cuDevicePrimaryCtxGetState );
497+ OPAL_CUDA_DLSYM (libcuda_handle , cuMemPoolGetAccess );
498+ OPAL_CUDA_DLSYM (libcuda_handle , cuDeviceGetAttribute );
491499 OPAL_CUDA_DLSYM (libcuda_handle , cuDeviceGetCount );
492500 OPAL_CUDA_DLSYM (libcuda_handle , cuMemRelease );
493501 OPAL_CUDA_DLSYM (libcuda_handle , cuMemRetainAllocationHandle );
@@ -1745,7 +1753,90 @@ static float mydifftime(opal_timer_t ts_start, opal_timer_t ts_end) {
17451753}
17461754#endif /* OPAL_ENABLE_DEBUG */
17471755
1748- static int mca_common_cuda_check_vmm (CUdeviceptr dbuf , CUmemorytype * mem_type )
1756+ static int mca_common_cuda_check_mpool (CUdeviceptr dbuf , CUmemorytype * mem_type ,
1757+ int * dev_id )
1758+ {
1759+ #if OPAL_CUDA_VMM_SUPPORT
1760+ static int device_count = -1 ;
1761+ static int mpool_supported = -1 ;
1762+ CUresult result ;
1763+ CUmemoryPool mpool ;
1764+ CUmemAccess_flags flags ;
1765+ CUmemLocation location ;
1766+
1767+ if (mpool_supported <= 0 ) {
1768+ if (mpool_supported == -1 ) {
1769+ if (device_count == -1 ) {
1770+ result = cuFunc .cuDeviceGetCount (& device_count );
1771+ if (result != CUDA_SUCCESS || (0 == device_count )) {
1772+ mpool_supported = 0 ; /* never check again */
1773+ device_count = 0 ;
1774+ return 0 ;
1775+ }
1776+ }
1777+
1778+ /* assume uniformity of devices */
1779+ result = cuFunc .cuDeviceGetAttribute (& mpool_supported ,
1780+ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , 0 );
1781+ if (result != CUDA_SUCCESS ) {
1782+ mpool_supported = 0 ;
1783+ }
1784+ }
1785+ if (0 == mpool_supported ) {
1786+ return 0 ;
1787+ }
1788+ }
1789+
1790+ result = cuFunc .cuPointerGetAttribute (& mpool ,
1791+ CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ,
1792+ dbuf );
1793+ if (CUDA_SUCCESS != result ) {
1794+ return 0 ;
1795+ }
1796+
1797+ /* check if device has access */
1798+ for (int i = 0 ; i < device_count ; i ++ ) {
1799+ location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
1800+ location .id = i ;
1801+ result = cuFunc .cuMemPoolGetAccess (& flags , mpool , & location );
1802+ if ((CUDA_SUCCESS == result ) &&
1803+ (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
1804+ * mem_type = CU_MEMORYTYPE_DEVICE ;
1805+ * dev_id = i ;
1806+ return 1 ;
1807+ }
1808+ }
1809+
1810+ /* host must have access as device access possibility is exhausted */
1811+ * mem_type = CU_MEMORYTYPE_HOST ;
1812+ * dev_id = -1 ;
1813+ return 0 ;
1814+ #endif
1815+
1816+ return 0 ;
1817+ }
1818+
1819+ static int mca_common_cuda_get_primary_context (CUdevice dev_id , CUcontext * pctx )
1820+ {
1821+ CUresult result ;
1822+ unsigned int flags ;
1823+ int active ;
1824+
1825+ result = cuFunc .cuDevicePrimaryCtxGetState (dev_id , & flags , & active );
1826+ if (CUDA_SUCCESS != result ) {
1827+ return OPAL_ERROR ;
1828+ }
1829+
1830+ if (active ) {
1831+ result = cuFunc .cuDevicePrimaryCtxRetain (pctx , dev_id );
1832+ return OPAL_SUCCESS ;
1833+ }
1834+
1835+ return OPAL_ERROR ;
1836+ }
1837+
1838+ static int mca_common_cuda_check_vmm (CUdeviceptr dbuf , CUmemorytype * mem_type ,
1839+ int * dev_id )
17491840{
17501841#if OPAL_CUDA_VMM_SUPPORT
17511842 static int device_count = -1 ;
@@ -1775,6 +1866,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
17751866
17761867 if (prop .location .type == CU_MEM_LOCATION_TYPE_DEVICE ) {
17771868 * mem_type = CU_MEMORYTYPE_DEVICE ;
1869+ * dev_id = prop .location .id ;
17781870 cuFunc .cuMemRelease (alloc_handle );
17791871 return 1 ;
17801872 }
@@ -1788,6 +1880,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
17881880 if ((CUDA_SUCCESS == result ) &&
17891881 (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
17901882 * mem_type = CU_MEMORYTYPE_DEVICE ;
1883+ * dev_id = i ;
17911884 cuFunc .cuMemRelease (alloc_handle );
17921885 return 1 ;
17931886 }
@@ -1796,6 +1889,7 @@ static int mca_common_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type)
17961889
17971890 /* host must have access as device access possibility is exhausted */
17981891 * mem_type = CU_MEMORYTYPE_HOST ;
1892+ * dev_id = -1 ;
17991893 cuFunc .cuMemRelease (alloc_handle );
18001894 return 1 ;
18011895
@@ -1809,12 +1903,17 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18091903{
18101904 int res ;
18111905 int is_vmm = 0 ;
1906+ int is_mpool = 0 ;
18121907 CUmemorytype vmm_mem_type = 0 ;
1908+ CUmemorytype mpool_mem_type = 0 ;
18131909 CUmemorytype memType = 0 ;
1910+ int vmm_dev_id = -1 ;
1911+ int mpool_dev_id = -1 ;
18141912 CUdeviceptr dbuf = (CUdeviceptr )pUserBuf ;
18151913 CUcontext ctx = NULL , memCtx = NULL ;
18161914
1817- is_vmm = mca_common_cuda_check_vmm (dbuf , & vmm_mem_type );
1915+ is_vmm = mca_common_cuda_check_vmm (dbuf , & vmm_mem_type , & vmm_dev_id );
1916+ is_mpool = mca_common_cuda_check_mpool (dbuf , & mpool_mem_type , & mpool_dev_id );
18181917
18191918#if OPAL_CUDA_GET_ATTRIBUTES
18201919 uint32_t isManaged = 0 ;
@@ -1844,6 +1943,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18441943 } else if (memType == CU_MEMORYTYPE_HOST ) {
18451944 if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
18461945 memType = CU_MEMORYTYPE_DEVICE ;
1946+ } else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
1947+ memType = CU_MEMORYTYPE_DEVICE ;
18471948 } else {
18481949 /* Host memory, nothing to do here */
18491950 return 0 ;
@@ -1864,6 +1965,8 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18641965 } else if (memType == CU_MEMORYTYPE_HOST ) {
18651966 if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
18661967 memType = CU_MEMORYTYPE_DEVICE ;
1968+ } else if (is_mpool && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
1969+ memType = CU_MEMORYTYPE_DEVICE ;
18671970 } else {
18681971 /* Host memory, nothing to do here */
18691972 return 0 ;
@@ -1893,14 +1996,18 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf, opal_convertor_t
18931996 return OPAL_ERROR ;
18941997 }
18951998#endif /* OPAL_CUDA_GET_ATTRIBUTES */
1896- if (is_vmm ) {
1897- /* This function is expected to set context if pointer is device
1898- * accessible but VMM allocations have NULL context associated
1899- * which cannot be set against the calling thread */
1900- opal_output (0 ,
1901- "CUDA: unable to set context with the given pointer"
1902- "ptr=%p aborting..." , dbuf );
1903- return OPAL_ERROR ;
1999+ if (is_vmm || is_mpool ) {
2000+ if (OPAL_SUCCESS ==
2001+ mca_common_cuda_get_primary_context (
2002+ is_vmm ? vmm_dev_id : mpool_dev_id , & memCtx )) {
2003+ /* As VMM/mempool allocations have no context associated
2004+ * with them, check if device primary context can be set */
2005+ } else {
2006+ opal_output (0 ,
2007+ "CUDA: unable to set ctx with the given pointer"
2008+ "ptr=%p aborting..." , pUserBuf );
2009+ return OPAL_ERROR ;
2010+ }
19042011 }
19052012
19062013 res = cuFunc .cuCtxSetCurrent (memCtx );
0 commit comments