1515#include "opal/runtime/opal.h"
1616#include "opal/mca/pmix/pmix.h"
1717#include "ompi/message/message.h"
18+ #include "ompi/mca/pml/base/pml_base_bsend.h"
1819#include "pml_ucx_request.h"
1920
2021#include <inttypes.h>
@@ -506,15 +507,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
506507 req -> flags = MCA_PML_UCX_REQUEST_FLAG_SEND ;
507508 req -> buffer = (void * )buf ;
508509 req -> count = count ;
509- req -> datatype = mca_pml_ucx_get_datatype (datatype );
510510 req -> tag = PML_UCX_MAKE_SEND_TAG (tag , comm );
511511 req -> send .mode = mode ;
512512 req -> send .ep = ep ;
513+ if (MCA_PML_BASE_SEND_BUFFERED == mode ) {
514+ req -> ompi_datatype = datatype ;
515+ OBJ_RETAIN (datatype );
516+ } else {
517+ req -> datatype = mca_pml_ucx_get_datatype (datatype );
518+ }
513519
514520 * request = & req -> ompi ;
515521 return OMPI_SUCCESS ;
516522}
517523
524+ static int
525+ mca_pml_ucx_bsend (ucp_ep_h ep , const void * buf , size_t count ,
526+ ompi_datatype_t * datatype , uint64_t pml_tag )
527+ {
528+ ompi_request_t * req ;
529+ void * packed_data ;
530+ size_t packed_length ;
531+ size_t offset ;
532+ uint32_t iov_count ;
533+ struct iovec iov ;
534+ opal_convertor_t opal_conv ;
535+
536+ OBJ_CONSTRUCT (& opal_conv , opal_convertor_t );
537+ opal_convertor_copy_and_prepare_for_recv (ompi_proc_local_proc -> super .proc_convertor ,
538+ & datatype -> super , count , buf , 0 ,
539+ & opal_conv );
540+ opal_convertor_get_packed_size (& opal_conv , & packed_length );
541+
542+ packed_data = mca_pml_base_bsend_request_alloc_buf (packed_length );
543+ if (OPAL_UNLIKELY (NULL == packed_data )) {
544+ OBJ_DESTRUCT (& opal_conv );
545+ PML_UCX_ERROR ("bsend: failed to allocate buffer" );
546+ return OMPI_ERR_OUT_OF_RESOURCE ;
547+ }
548+
549+ iov_count = 1 ;
550+ iov .iov_base = packed_data ;
551+ iov .iov_len = packed_length ;
552+
553+ PML_UCX_VERBOSE (8 , "bsend of packed buffer %p len %d\n" , packed_data , packed_length );
554+ offset = 0 ;
555+ opal_convertor_set_position (& opal_conv , & offset );
556+ if (0 > opal_convertor_pack (& opal_conv , & iov , & iov_count , & packed_length )) {
557+ mca_pml_base_bsend_request_free (packed_data );
558+ OBJ_DESTRUCT (& opal_conv );
559+ PML_UCX_ERROR ("bsend: failed to pack user datatype" );
560+ return OMPI_ERROR ;
561+ }
562+
563+ OBJ_DESTRUCT (& opal_conv );
564+
565+ req = (ompi_request_t * )ucp_tag_send_nb (ep , packed_data , packed_length ,
566+ ucp_dt_make_contig (1 ), pml_tag ,
567+ mca_pml_ucx_bsend_completion );
568+ if (NULL == req ) {
569+ /* request was completed in place */
570+ mca_pml_base_bsend_request_free (packed_data );
571+ return OMPI_SUCCESS ;
572+ }
573+
574+ if (OPAL_UNLIKELY (UCS_PTR_IS_ERR (req ))) {
575+ mca_pml_base_bsend_request_free (packed_data );
576+ PML_UCX_ERROR ("ucx bsend failed: %s" , ucs_status_string (UCS_PTR_STATUS (req )));
577+ return OMPI_ERROR ;
578+ }
579+
580+ req -> req_complete_cb_data = packed_data ;
581+ return OMPI_SUCCESS ;
582+ }
583+
518584int mca_pml_ucx_isend (const void * buf , size_t count , ompi_datatype_t * datatype ,
519585 int dst , int tag , mca_pml_base_send_mode_t mode ,
520586 struct ompi_communicator_t * comm ,
@@ -523,8 +589,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
523589 ompi_request_t * req ;
524590 ucp_ep_h ep ;
525591
526- PML_UCX_TRACE_SEND ("isend request *%p" , buf , count , datatype , dst , tag , mode ,
527- comm , (void * )request )
592+ PML_UCX_TRACE_SEND ("i%ssend request *%p" ,
593+ buf , count , datatype , dst , tag , mode , comm ,
594+ mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "" ,
595+ (void * )request )
528596
529597 /* TODO special care to sync/buffered send */
530598
@@ -534,6 +602,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
534602 return OMPI_ERROR ;
535603 }
536604
605+ /* Special care to sync/buffered send */
606+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
607+ * request = & ompi_pml_ucx .completed_send_req ;
608+ return mca_pml_ucx_bsend (ep , buf , count , datatype ,
609+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
610+ }
611+
537612 req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
538613 mca_pml_ucx_get_datatype (datatype ),
539614 PML_UCX_MAKE_SEND_TAG (tag , comm ),
@@ -559,16 +634,21 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
559634 ompi_request_t * req ;
560635 ucp_ep_h ep ;
561636
562- PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm , "send" );
563-
564- /* TODO special care to sync/buffered send */
637+ PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm ,
638+ mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send" );
565639
566640 ep = mca_pml_ucx_get_ep (comm , dst );
567641 if (OPAL_UNLIKELY (NULL == ep )) {
568642 PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
569643 return OMPI_ERROR ;
570644 }
571645
646+ /* Special care to sync/buffered send */
647+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
648+ return mca_pml_ucx_bsend (ep , buf , count , datatype ,
649+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
650+ }
651+
572652 req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
573653 mca_pml_ucx_get_datatype (datatype ),
574654 PML_UCX_MAKE_SEND_TAG (tag , comm ),
@@ -729,6 +809,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
729809 mca_pml_ucx_persistent_request_t * preq ;
730810 ompi_request_t * tmp_req ;
731811 size_t i ;
812+ int rc ;
732813
733814 for (i = 0 ; i < count ; ++ i ) {
734815 preq = (mca_pml_ucx_persistent_request_t * )requests [i ];
@@ -743,12 +824,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
743824 mca_pml_ucx_request_reset (& preq -> ompi );
744825
745826 if (preq -> flags & MCA_PML_UCX_REQUEST_FLAG_SEND ) {
746- /* TODO special care to sync/buffered send */
747- PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
748- tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
749- preq -> count , preq -> datatype ,
750- preq -> tag ,
751- mca_pml_ucx_psend_completion );
827+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == preq -> send .mode )) {
828+ PML_UCX_VERBOSE (8 , "start bsend request %p" , (void * )preq );
829+ rc = mca_pml_ucx_bsend (preq -> send .ep , preq -> buffer , preq -> count ,
830+ preq -> ompi_datatype , preq -> tag );
831+ if (OMPI_SUCCESS != rc ) {
832+ return rc ;
833+ }
834+ /* pretend that we got immediate completion */
835+ tmp_req = NULL ;
836+ } else {
837+ PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
838+ tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
839+ preq -> count , preq -> datatype ,
840+ preq -> tag ,
841+ mca_pml_ucx_psend_completion );
842+ }
752843 } else {
753844 PML_UCX_VERBOSE (8 , "start recv request %p" , (void * )preq );
754845 tmp_req = (ompi_request_t * )ucp_tag_recv_nb (ompi_pml_ucx .ucp_worker ,
0 commit comments