1515#include "opal/runtime/opal.h"
1616#include "opal/mca/pmix/pmix.h"
1717#include "ompi/message/message.h"
18+ #include "ompi/mca/pml/base/pml_base_bsend.h"
1819#include "pml_ucx_request.h"
1920
2021#include <inttypes.h>
@@ -333,7 +334,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
333334 ucs_status_t status ;
334335 size_t i ;
335336
336- PML_UCX_VERBOSE (2 , "waiting for %d disconnect requests" , * count_p );
337+ PML_UCX_VERBOSE (2 , "waiting for %d disconnect requests" , ( int ) * count_p );
337338 for (i = 0 ; i < * count_p ; ++ i ) {
338339 do {
339340 opal_progress ();
@@ -343,7 +344,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
343344 PML_UCX_ERROR ("disconnect request failed: %s" ,
344345 ucs_status_string (status ));
345346 }
346- ucp_request_release (reqs [i ]);
347+ ucp_request_free (reqs [i ]);
347348 reqs [i ] = NULL ;
348349 }
349350
@@ -391,7 +392,7 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
391392
392393 proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ] = NULL ;
393394
394- if (num_reqs >= ompi_pml_ucx .num_disconnect ) {
395+ if (( int ) num_reqs >= ompi_pml_ucx .num_disconnect ) {
395396 mca_pml_ucx_waitall (dreqs , & num_reqs );
396397 }
397398 }
@@ -494,7 +495,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
494495 PML_UCX_TRACE_RECV ("%s" , buf , count , datatype , src , tag , comm , "recv" );
495496
496497 PML_UCX_MAKE_RECV_TAG (ucp_tag , ucp_tag_mask , tag , src , comm );
497- req = alloca (ompi_pml_ucx .request_size ) + ompi_pml_ucx .request_size ;
498+ req = ( char * ) alloca (ompi_pml_ucx .request_size ) + ompi_pml_ucx .request_size ;
498499 status = ucp_tag_recv_nbr (ompi_pml_ucx .ucp_worker , buf , count ,
499500 mca_pml_ucx_get_datatype (datatype ),
500501 ucp_tag , ucp_tag_mask , req );
@@ -556,15 +557,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
556557 req -> flags = MCA_PML_UCX_REQUEST_FLAG_SEND ;
557558 req -> buffer = (void * )buf ;
558559 req -> count = count ;
559- req -> datatype = mca_pml_ucx_get_datatype (datatype );
560560 req -> tag = PML_UCX_MAKE_SEND_TAG (tag , comm );
561561 req -> send .mode = mode ;
562562 req -> send .ep = ep ;
563+ if (MCA_PML_BASE_SEND_BUFFERED == mode ) {
564+ req -> ompi_datatype = datatype ;
565+ OBJ_RETAIN (datatype );
566+ } else {
567+ req -> datatype = mca_pml_ucx_get_datatype (datatype );
568+ }
563569
564570 * request = & req -> ompi ;
565571 return OMPI_SUCCESS ;
566572}
567573
574+ static int
575+ mca_pml_ucx_bsend (ucp_ep_h ep , const void * buf , size_t count ,
576+ ompi_datatype_t * datatype , uint64_t pml_tag )
577+ {
578+ ompi_request_t * req ;
579+ void * packed_data ;
580+ size_t packed_length ;
581+ size_t offset ;
582+ uint32_t iov_count ;
583+ struct iovec iov ;
584+ opal_convertor_t opal_conv ;
585+
586+ OBJ_CONSTRUCT (& opal_conv , opal_convertor_t );
587+ opal_convertor_copy_and_prepare_for_recv (ompi_proc_local_proc -> super .proc_convertor ,
588+ & datatype -> super , count , buf , 0 ,
589+ & opal_conv );
590+ opal_convertor_get_packed_size (& opal_conv , & packed_length );
591+
592+ packed_data = mca_pml_base_bsend_request_alloc_buf (packed_length );
593+ if (OPAL_UNLIKELY (NULL == packed_data )) {
594+ OBJ_DESTRUCT (& opal_conv );
595+ PML_UCX_ERROR ("bsend: failed to allocate buffer" );
596+ return OMPI_ERR_OUT_OF_RESOURCE ;
597+ }
598+
599+ iov_count = 1 ;
600+ iov .iov_base = packed_data ;
601+ iov .iov_len = packed_length ;
602+
603+ PML_UCX_VERBOSE (8 , "bsend of packed buffer %p len %d\n" , packed_data , packed_length );
604+ offset = 0 ;
605+ opal_convertor_set_position (& opal_conv , & offset );
606+ if (0 > opal_convertor_pack (& opal_conv , & iov , & iov_count , & packed_length )) {
607+ mca_pml_base_bsend_request_free (packed_data );
608+ OBJ_DESTRUCT (& opal_conv );
609+ PML_UCX_ERROR ("bsend: failed to pack user datatype" );
610+ return OMPI_ERROR ;
611+ }
612+
613+ OBJ_DESTRUCT (& opal_conv );
614+
615+ req = (ompi_request_t * )ucp_tag_send_nb (ep , packed_data , packed_length ,
616+ ucp_dt_make_contig (1 ), pml_tag ,
617+ mca_pml_ucx_bsend_completion );
618+ if (NULL == req ) {
619+ /* request was completed in place */
620+ mca_pml_base_bsend_request_free (packed_data );
621+ return OMPI_SUCCESS ;
622+ }
623+
624+ if (OPAL_UNLIKELY (UCS_PTR_IS_ERR (req ))) {
625+ mca_pml_base_bsend_request_free (packed_data );
626+ PML_UCX_ERROR ("ucx bsend failed: %s" , ucs_status_string (UCS_PTR_STATUS (req )));
627+ return OMPI_ERROR ;
628+ }
629+
630+ req -> req_complete_cb_data = packed_data ;
631+ return OMPI_SUCCESS ;
632+ }
633+
568634int mca_pml_ucx_isend (const void * buf , size_t count , ompi_datatype_t * datatype ,
569635 int dst , int tag , mca_pml_base_send_mode_t mode ,
570636 struct ompi_communicator_t * comm ,
@@ -573,8 +639,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
573639 ompi_request_t * req ;
574640 ucp_ep_h ep ;
575641
576- PML_UCX_TRACE_SEND ("isend request *%p" , buf , count , datatype , dst , tag , mode ,
577- comm , (void * )request )
642+ PML_UCX_TRACE_SEND ("i%ssend request *%p" ,
643+ buf , count , datatype , dst , tag , mode , comm ,
644+ mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "" ,
645+ (void * )request )
578646
579647 /* TODO special care to sync/buffered send */
580648
@@ -584,6 +652,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
584652 return OMPI_ERROR ;
585653 }
586654
655+ /* Special care to sync/buffered send */
656+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
657+ * request = & ompi_pml_ucx .completed_send_req ;
658+ return mca_pml_ucx_bsend (ep , buf , count , datatype ,
659+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
660+ }
661+
587662 req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
588663 mca_pml_ucx_get_datatype (datatype ),
589664 PML_UCX_MAKE_SEND_TAG (tag , comm ),
@@ -609,16 +684,21 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
609684 ompi_request_t * req ;
610685 ucp_ep_h ep ;
611686
612- PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm , "send" );
613-
614- /* TODO special care to sync/buffered send */
687+ PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm ,
688+ mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send" );
615689
616690 ep = mca_pml_ucx_get_ep (comm , dst );
617691 if (OPAL_UNLIKELY (NULL == ep )) {
618692 PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
619693 return OMPI_ERROR ;
620694 }
621695
696+ /* Special care to sync/buffered send */
697+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
698+ return mca_pml_ucx_bsend (ep , buf , count , datatype ,
699+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
700+ }
701+
622702 req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
623703 mca_pml_ucx_get_datatype (datatype ),
624704 PML_UCX_MAKE_SEND_TAG (tag , comm ),
@@ -781,6 +861,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
781861 mca_pml_ucx_persistent_request_t * preq ;
782862 ompi_request_t * tmp_req ;
783863 size_t i ;
864+ int rc ;
784865
785866 for (i = 0 ; i < count ; ++ i ) {
786867 preq = (mca_pml_ucx_persistent_request_t * )requests [i ];
@@ -795,12 +876,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
795876 mca_pml_ucx_request_reset (& preq -> ompi );
796877
797878 if (preq -> flags & MCA_PML_UCX_REQUEST_FLAG_SEND ) {
798- /* TODO special care to sync/buffered send */
799- PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
800- tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
801- preq -> count , preq -> datatype ,
802- preq -> tag ,
803- mca_pml_ucx_psend_completion );
879+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == preq -> send .mode )) {
880+ PML_UCX_VERBOSE (8 , "start bsend request %p" , (void * )preq );
881+ rc = mca_pml_ucx_bsend (preq -> send .ep , preq -> buffer , preq -> count ,
882+ preq -> ompi_datatype , preq -> tag );
883+ if (OMPI_SUCCESS != rc ) {
884+ return rc ;
885+ }
886+ /* pretend that we got immediate completion */
887+ tmp_req = NULL ;
888+ } else {
889+ PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
890+ tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
891+ preq -> count , preq -> datatype ,
892+ preq -> tag ,
893+ mca_pml_ucx_psend_completion );
894+ }
804895 } else {
805896 PML_UCX_VERBOSE (8 , "start recv request %p" , (void * )preq );
806897 tmp_req = (ompi_request_t * )ucp_tag_recv_nb (ompi_pml_ucx .ucp_worker ,
0 commit comments