@@ -25,6 +25,12 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr)
2525 return OMPI_SUCCESS ;
2626}
2727
28+ static int mca_pml_ucx_request_cancel (ompi_request_t * req , int flag )
29+ {
30+ ucp_request_cancel (ompi_pml_ucx .ucp_worker , req );
31+ return OMPI_SUCCESS ;
32+ }
33+
2834void mca_pml_ucx_send_completion (void * request , ucs_status_t status )
2935{
3036 ompi_request_t * req = request ;
@@ -55,12 +61,19 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
5561 OPAL_THREAD_UNLOCK (& ompi_request_lock );
5662}
5763
58- void mca_pml_ucx_persistent_requset_complete (mca_pml_ucx_persistent_request_t * preq ,
64+ static void mca_pml_ucx_persistent_request_detach (mca_pml_ucx_persistent_request_t * preq ,
65+ ompi_request_t * tmp_req )
66+ {
67+ tmp_req -> req_complete_cb_data = NULL ;
68+ preq -> tmp_req = NULL ;
69+ }
70+
71+ void mca_pml_ucx_persistent_request_complete (mca_pml_ucx_persistent_request_t * preq ,
5972 ompi_request_t * tmp_req )
6073{
6174 preq -> ompi .req_status = tmp_req -> req_status ;
6275 ompi_request_complete (& preq -> ompi , true);
63- tmp_req -> req_complete_cb_data = NULL ;
76+ mca_pml_ucx_persistent_request_detach ( preq , tmp_req ) ;
6477 mca_pml_ucx_request_reset (tmp_req );
6578 ucp_request_release (tmp_req );
6679}
@@ -73,7 +86,8 @@ static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
7386 ompi_request_complete (tmp_req , false);
7487 preq = (mca_pml_ucx_persistent_request_t * )tmp_req -> req_complete_cb_data ;
7588 if (preq != NULL ) {
76- mca_pml_ucx_persistent_requset_complete (preq , tmp_req );
89+ PML_UCX_ASSERT (preq -> tmp_req != NULL );
90+ mca_pml_ucx_persistent_request_complete (preq , tmp_req );
7791 }
7892 OPAL_THREAD_UNLOCK (& ompi_request_lock );
7993}
@@ -120,7 +134,8 @@ void mca_pml_ucx_request_init(void *request)
120134 ompi_request_t * ompi_req = request ;
121135 OBJ_CONSTRUCT (ompi_req , ompi_request_t );
122136 mca_pml_ucx_request_init_common (ompi_req , false, OMPI_REQUEST_ACTIVE ,
123- mca_pml_ucx_request_free , NULL );
137+ mca_pml_ucx_request_free ,
138+ mca_pml_ucx_request_cancel );
124139}
125140
126141void mca_pml_ucx_request_cleanup (void * request )
@@ -133,18 +148,35 @@ void mca_pml_ucx_request_cleanup(void *request)
133148
134149static int mca_pml_ucx_persistent_request_free (ompi_request_t * * rptr )
135150{
136- mca_pml_ucx_persistent_request_t * req = (mca_pml_ucx_persistent_request_t * )* rptr ;
151+ mca_pml_ucx_persistent_request_t * preq = (mca_pml_ucx_persistent_request_t * )* rptr ;
152+ ompi_request_t * tmp_req = preq -> tmp_req ;
137153
154+ preq -> ompi .req_state = OMPI_REQUEST_INVALID ;
155+ if (tmp_req != NULL ) {
156+ mca_pml_ucx_persistent_request_detach (preq , tmp_req );
157+ ucp_request_release (tmp_req );
158+ }
159+ PML_UCX_FREELIST_RETURN (& ompi_pml_ucx .persistent_reqs , & preq -> ompi .super );
138160 * rptr = MPI_REQUEST_NULL ;
139- req -> ompi .req_state = OMPI_REQUEST_INVALID ;
140- PML_UCX_FREELIST_RETURN (& ompi_pml_ucx .persistent_reqs , & req -> ompi .super );
161+ return OMPI_SUCCESS ;
162+ }
163+
164+ static int mca_pml_ucx_persistent_request_cancel (ompi_request_t * req , int flag )
165+ {
166+ mca_pml_ucx_persistent_request_t * preq = (mca_pml_ucx_persistent_request_t * )req ;
167+
168+ if (preq -> tmp_req != NULL ) {
169+ ucp_request_cancel (ompi_pml_ucx .ucp_worker , preq -> tmp_req );
170+ }
141171 return OMPI_SUCCESS ;
142172}
143173
144174static void mca_pml_ucx_persisternt_request_construct (mca_pml_ucx_persistent_request_t * req )
145175{
146176 mca_pml_ucx_request_init_common (& req -> ompi , true, OMPI_REQUEST_INACTIVE ,
147- mca_pml_ucx_persistent_request_free , NULL );
177+ mca_pml_ucx_persistent_request_free ,
178+ mca_pml_ucx_persistent_request_cancel );
179+ req -> tmp_req = NULL ;
148180}
149181
150182static void mca_pml_ucx_persisternt_request_destruct (mca_pml_ucx_persistent_request_t * req )
0 commit comments