99
1010#include "cuda_ipc_md.h"
1111#include "cuda_ipc_cache.h"
12-
12+ #include "cuda_ipc.inl"
1313#include <string.h>
1414#include <limits.h>
1515#include <ucs/debug/log.h>
@@ -105,87 +105,13 @@ uct_cuda_ipc_md_query(uct_md_h md, uct_md_attr_v2_t *md_attr)
105105 return UCS_OK ;
106106}
107107
108- static ucs_status_t
109- uct_cuda_ipc_mem_reg_push_ctx (CUdeviceptr address , CUdevice * cuda_device_p ,
110- int * is_ctx_pushed , int * is_ctx_retained )
111- {
112- #define UCT_CUDA_IPC_NUM_ATTRS 2
113- CUcontext cuda_curr_ctx , cuda_ctx ;
114- CUdevice cuda_device ;
115- CUpointer_attribute attr_type [UCT_CUDA_IPC_NUM_ATTRS ];
116- void * attr_data [UCT_CUDA_IPC_NUM_ATTRS ];
117- int cuda_device_ordinal ;
118- ucs_status_t status ;
119-
120- attr_type [0 ] = CU_POINTER_ATTRIBUTE_CONTEXT ;
121- attr_data [0 ] = & cuda_ctx ;
122- attr_type [1 ] = CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL ;
123- attr_data [1 ] = & cuda_device_ordinal ;
124-
125- status = UCT_CUDADRV_FUNC_LOG_ERR (
126- cuPointerGetAttributes (UCT_CUDA_IPC_NUM_ATTRS , attr_type , attr_data ,
127- address ));
128- if (status != UCS_OK ) {
129- return status ;
130- }
131-
132- ucs_assertv (cuda_device_ordinal >= 0 , "cuda_device_ordinal=%d" ,
133- cuda_device_ordinal );
134-
135- status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGet (& cuda_device ,
136- cuda_device_ordinal ));
137- if (status != UCS_OK ) {
138- return status ;
139- }
140-
141- * is_ctx_pushed = 0 ;
142- * cuda_device_p = cuda_device ;
143-
144- if (cuda_ctx == NULL ) {
145- status = uct_cuda_primary_ctx_retain (* cuda_device_p , 0 , & cuda_ctx );
146- if (status != UCS_OK ) {
147- return status ;
148- }
149-
150- * is_ctx_retained = 1 ;
151- } else {
152- * is_ctx_retained = 0 ;
153- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxGetCurrent (& cuda_curr_ctx ));
154- if ((status != UCS_OK ) || (cuda_curr_ctx == cuda_ctx )) {
155- /* Failed to get current context or the pointer's context is
156- * already current, no need to push/pop */
157- return status ;
158- }
159- }
160-
161- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_ctx ));
162- if (status != UCS_OK ) {
163- if (* is_ctx_retained ) {
164- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (* cuda_device_p ));
165- }
166- return status ;
167- }
168-
169- * is_ctx_pushed = 1 ;
170- return UCS_OK ;
171- }
172-
173- static void uct_cuda_ipc_mem_reg_pop_ctx (CUdevice cuda_device ,
174- int is_ctx_retained )
175- {
176- UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
177- if (is_ctx_retained ) {
178- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_device ));
179- }
180- }
181-
182108static ucs_status_t
183109uct_cuda_ipc_mem_add_reg (void * addr , uct_cuda_ipc_memh_t * memh ,
184110 uct_cuda_ipc_lkey_t * * key_p )
185111{
186112 uct_cuda_ipc_lkey_t * key ;
187113 ucs_status_t status ;
188- int is_ctx_pushed , is_ctx_retained ;
114+ int is_ctx_pushed ;
189115 CUdevice cuda_device ;
190116#if HAVE_CUDA_FABRIC
191117#define UCT_CUDA_IPC_QUERY_NUM_ATTRS 3
@@ -202,8 +128,8 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
202128 return UCS_ERR_NO_MEMORY ;
203129 }
204130
205- status = uct_cuda_ipc_mem_reg_push_ctx ((CUdeviceptr )addr , & cuda_device ,
206- & is_ctx_pushed , & is_ctx_retained );
131+ status = uct_cuda_ipc_check_and_push_ctx ((CUdeviceptr )addr , & cuda_device ,
132+ & is_ctx_pushed );
207133 if (status != UCS_OK ) {
208134 goto out ;
209135 }
@@ -313,17 +239,15 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
313239 ucs_list_add_tail (& memh -> list , & key -> link );
314240 ucs_trace ("registered addr:%p/%p length:%zd type:%u dev_num:%d "
315241 "buffer_id:%llu" ,
316- addr , (void * )key -> d_bptr , key -> b_len , key -> ph .handle_type ,
317- memh -> dev_num , key -> ph .buffer_id );
242+ addr , (void * )key -> d_bptr , key -> b_len , key -> ph .handle_type ,
243+ cuda_device , key -> ph .buffer_id );
318244
319245 memh -> dev_num = cuda_device ;
320246 * key_p = key ;
321247 status = UCS_OK ;
322248
323249out_pop_ctx :
324- if (is_ctx_pushed ) {
325- uct_cuda_ipc_mem_reg_pop_ctx (memh -> dev_num , is_ctx_retained );
326- }
250+ uct_cuda_ipc_check_and_pop_ctx (is_ctx_pushed );
327251out :
328252 if (status != UCS_OK ) {
329253 ucs_free (key );
@@ -394,10 +318,6 @@ uct_cuda_ipc_is_peer_accessible(uct_cuda_ipc_component_t *component,
394318 }
395319 }
396320
397- /* Save local device number, so we use it to find remote rcache when mapping
398- * mem_handle in uct_cuda_ipc_post_cuda_async_copy */
399- rkey -> super .dev_num = cu_dev ;
400-
401321 pthread_mutex_lock (& component -> lock );
402322
403323 cache = uct_cuda_ipc_get_dev_cache (component , & rkey -> super );
0 commit comments