@@ -62,6 +62,19 @@ uct_rc_gdaki_alloc(size_t size, size_t align, void **p_buf, CUdeviceptr *p_orig)
6262 return status ;
6363}
6464
65+ static void
66+ uct_rc_gdaki_calc_dev_ep_layout (size_t cq_umem_len , unsigned max_tx ,
67+ size_t * cq_umem_offset_p ,
68+ size_t * qp_umem_offset_p ,
69+ size_t * dev_ep_size_p )
70+ {
71+ * cq_umem_offset_p = ucs_align_up_pow2 (sizeof (uct_rc_gdaki_dev_ep_t ),
72+ ucs_get_page_size ());
73+ * qp_umem_offset_p = ucs_align_up_pow2 (* cq_umem_offset_p + cq_umem_len ,
74+ ucs_get_page_size ());
75+ * dev_ep_size_p = * qp_umem_offset_p + max_tx * MLX5_SEND_WQE_BB ;
76+ }
77+
6578static UCS_CLASS_INIT_FUNC (uct_rc_gdaki_ep_t , const uct_ep_params_t * params )
6679{
6780 uct_rc_gdaki_iface_t * iface = ucs_derived_of (params -> iface ,
@@ -71,13 +84,14 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
7184 uct_ib_iface_init_attr_t init_attr = {};
7285 uct_ib_mlx5_cq_attr_t cq_attr = {};
7386 uct_ib_mlx5_qp_attr_t qp_attr = {};
74- uct_rc_gdaki_dev_ep_t dev_ep = {};
7587 ucs_status_t status ;
7688 size_t dev_ep_size ;
7789 uct_ib_mlx5_dbrec_t dbrec ;
7890
7991 UCS_CLASS_CALL_SUPER_INIT (uct_base_ep_t , & iface -> super .super .super .super );
8092
93+ self -> dev_ep_init = 0 ;
94+
8195 status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (iface -> cuda_ctx ));
8296 if (status != UCS_OK ) {
8397 return status ;
@@ -91,37 +105,29 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
91105 uct_ib_mlx5_wq_calc_sizes (& qp_attr );
92106
93107 cq_attr .flags |= UCT_IB_MLX5_CQ_IGNORE_OVERRUN ;
94- cq_attr .umem_offset = ucs_align_up_pow2 (sizeof (uct_rc_gdaki_dev_ep_t ),
95- ucs_get_page_size ());
96108
97109 qp_attr .mmio_mode = UCT_IB_MLX5_MMIO_MODE_DB ;
98110 qp_attr .super .srq_num = 0 ;
99- qp_attr .umem_offset = ucs_align_up_pow2 (cq_attr .umem_offset +
100- cq_attr .umem_len ,
101- ucs_get_page_size ());
102111
103112 /* Disable inline scatter to TX CQE */
104113 qp_attr .super .max_inl_cqe [UCT_IB_DIR_TX ] = 0 ;
105114
106- dev_ep_size = qp_attr .umem_offset + qp_attr .len ;
107115 /*
108116 * dev_ep layout:
109117 * +---------------------+---------+---------+
110118 * | counters, dbr | cq buff | wq buff |
111119 * +---------------------+---------+---------+
112120 */
121+ uct_rc_gdaki_calc_dev_ep_layout (cq_attr .umem_len , qp_attr .max_tx ,
122+ & cq_attr .umem_offset ,
123+ & qp_attr .umem_offset , & dev_ep_size );
124+
113125 status = uct_rc_gdaki_alloc (dev_ep_size , ucs_get_page_size (),
114126 (void * * )& self -> ep_gpu , & self -> ep_raw );
115127 if (status != UCS_OK ) {
116128 goto err_ctx ;
117129 }
118130
119- status = UCT_CUDADRV_FUNC_LOG_ERR (
120- cuMemsetD8 ((CUdeviceptr )self -> ep_gpu , 0 , dev_ep_size ));
121- if (status != UCS_OK ) {
122- goto err_mem ;
123- }
124-
125131 /* TODO add dmabuf_fd support */
126132 self -> umem = mlx5dv_devx_umem_reg (md -> super .dev .ibv_context , self -> ep_gpu ,
127133 dev_ep_size , IBV_ACCESS_LOCAL_WRITE );
@@ -169,40 +175,11 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
169175 goto err_dev_ep ;
170176 }
171177
172- dev_ep .atomic_va = iface -> atomic_buff ;
173- dev_ep .atomic_lkey = htonl (iface -> atomic_mr -> lkey );
174-
175- dev_ep .sq_num = self -> qp .super .qp_num ;
176- dev_ep .sq_wqe_daddr = UCS_PTR_BYTE_OFFSET (self -> ep_gpu ,
177- qp_attr .umem_offset );
178- dev_ep .sq_wqe_num = qp_attr .max_tx ;
179- dev_ep .sq_dbrec = & self -> ep_gpu -> qp_dbrec [MLX5_SND_DBR ];
180- /* FC mask is used to determine if WQE should be posted with completion.
181- * qp_attr.max_tx must be a power of 2. */
182- dev_ep .sq_fc_mask = (qp_attr .max_tx >> 1 ) - 1 ;
183- dev_ep .cqe_daddr = UCS_PTR_BYTE_OFFSET (self -> ep_gpu , cq_attr .umem_offset );
184- dev_ep .cqe_num = cq_attr .cq_size ;
185- dev_ep .sq_db = self -> sq_db ;
186-
187- status = UCT_CUDADRV_FUNC_LOG_ERR (
188- cuMemsetD8 ((CUdeviceptr )UCS_PTR_BYTE_OFFSET (self -> ep_gpu ,
189- cq_attr .umem_offset ),
190- 0xff , cq_attr .umem_len ));
191- if (status != UCS_OK ) {
192- goto err_dev_ep ;
193- }
194-
195- status = UCT_CUDADRV_FUNC_LOG_ERR (
196- cuMemcpyHtoD ((CUdeviceptr )self -> ep_gpu , & dev_ep , sizeof (dev_ep )));
197- if (status != UCS_OK ) {
198- goto err_dev_ep ;
199- }
200-
201178 (void )UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
202179 return UCS_OK ;
203180
204181err_dev_ep :
205- (void )cuMemHostUnregister (dev_ep . sq_db );
182+ (void )cuMemHostUnregister (self -> qp . reg -> addr . ptr );
206183 uct_ib_mlx5_devx_destroy_qp_common (& self -> qp .super );
207184err_cq :
208185 uct_ib_mlx5_devx_destroy_cq_common (& self -> cq );
@@ -372,10 +349,74 @@ uct_rc_gdaki_create_cq(uct_ib_iface_t *ib_iface, uct_ib_dir_t dir,
372349ucs_status_t
373350uct_rc_gdaki_ep_get_device_ep (uct_ep_h tl_ep , uct_device_ep_h * device_ep_p )
374351{
375- uct_rc_gdaki_ep_t * ep = ucs_derived_of (tl_ep , uct_rc_gdaki_ep_t );
352+ uct_rc_gdaki_ep_t * ep = ucs_derived_of (tl_ep , uct_rc_gdaki_ep_t );
353+ uct_rc_gdaki_iface_t * iface = ucs_derived_of (ep -> super .super .iface ,
354+ uct_rc_gdaki_iface_t );
355+ uct_rc_gdaki_dev_ep_t dev_ep = {};
356+ unsigned cq_size , cqe_size , max_tx ;
357+ size_t cq_umem_offset , cq_umem_len , qp_umem_offset , dev_ep_size ;
358+ ucs_status_t status ;
359+
360+ if (!ep -> dev_ep_init ) {
361+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (iface -> cuda_ctx ));
362+ if (status != UCS_OK ) {
363+ return status ;
364+ }
365+
366+ cq_size = UCS_BIT (ep -> cq .cq_length_log );
367+ cqe_size = UCS_BIT (ep -> cq .cqe_size_log );
368+ cq_umem_len = cqe_size * cq_size ;
369+ /* Reconstruct original max_tx from bb_max */
370+ max_tx = ep -> qp .bb_max + 2 * UCT_IB_MLX5_MAX_BB ;
371+
372+ uct_rc_gdaki_calc_dev_ep_layout (cq_umem_len , max_tx ,
373+ & cq_umem_offset ,
374+ & qp_umem_offset , & dev_ep_size );
375+
376+ status = UCT_CUDADRV_FUNC_LOG_ERR (
377+ cuMemsetD8 ((CUdeviceptr )ep -> ep_gpu , 0 , dev_ep_size ));
378+ if (status != UCS_OK ) {
379+ goto out_ctx ;
380+ }
381+
382+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuMemsetD8 (
383+ (CUdeviceptr )UCS_PTR_BYTE_OFFSET (ep -> ep_gpu , cq_umem_offset ),
384+ 0xff , cq_umem_len ));
385+ if (status != UCS_OK ) {
386+ goto out_ctx ;
387+ }
388+
389+ dev_ep .atomic_va = iface -> atomic_buff ;
390+ dev_ep .atomic_lkey = htonl (iface -> atomic_mr -> lkey );
391+ dev_ep .sq_num = ep -> qp .super .qp_num ;
392+ dev_ep .sq_wqe_daddr = UCS_PTR_BYTE_OFFSET (ep -> ep_gpu , qp_umem_offset );
393+ dev_ep .sq_dbrec = & ep -> ep_gpu -> qp_dbrec [MLX5_SND_DBR ];
394+ dev_ep .sq_wqe_num = max_tx ;
395+ /* FC mask is used to determine if WQE should be posted with completion.
396+ * max_tx must be a power of 2. */
397+ dev_ep .sq_fc_mask = (max_tx >> 1 ) - 1 ;
398+
399+ dev_ep .cqe_daddr = UCS_PTR_BYTE_OFFSET (ep -> ep_gpu , cq_umem_offset );
400+ dev_ep .cqe_num = cq_size ;
401+ dev_ep .sq_db = ep -> sq_db ;
402+
403+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuMemcpyHtoD (
404+ (CUdeviceptr )ep -> ep_gpu , & dev_ep , sizeof (dev_ep )));
405+ if (status != UCS_OK ) {
406+ goto out_ctx ;
407+ }
408+
409+ (void )UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
410+
411+ ep -> dev_ep_init = 1 ;
412+ }
376413
377414 * device_ep_p = & ep -> ep_gpu -> super ;
378415 return UCS_OK ;
416+
417+ out_ctx :
418+ (void )UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
419+ return status ;
379420}
380421
381422ucs_status_t
0 commit comments