@@ -75,6 +75,9 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
7575 NULL /* ucp_worker */
7676};
7777
78+ #define PML_UCX_REQ_ALLOCA () \
79+ ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size);
80+
7881static int mca_pml_ucx_send_worker_address (void )
7982{
8083 ucp_address_t * address ;
@@ -525,7 +528,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
525528 PML_UCX_TRACE_RECV ("%s" , buf , count , datatype , src , tag , comm , "recv" );
526529
527530 PML_UCX_MAKE_RECV_TAG (ucp_tag , ucp_tag_mask , tag , src , comm );
528- req = ( char * ) alloca ( ompi_pml_ucx . request_size ) + ompi_pml_ucx . request_size ;
531+ req = PML_UCX_REQ_ALLOCA () ;
529532 status = ucp_tag_recv_nbr (ompi_pml_ucx .ucp_worker , buf , count ,
530533 mca_pml_ucx_get_datatype (datatype ),
531534 ucp_tag , ucp_tag_mask , req );
@@ -715,26 +718,18 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
715718 }
716719}
717720
718- int mca_pml_ucx_send (const void * buf , size_t count , ompi_datatype_t * datatype , int dst ,
719- int tag , mca_pml_base_send_mode_t mode ,
720- struct ompi_communicator_t * comm )
721+ static inline __opal_attribute_always_inline__ int
722+ mca_pml_ucx_send_nb (ucp_ep_h ep , const void * buf , size_t count ,
723+ ompi_datatype_t * datatype , ucp_datatype_t ucx_datatype ,
724+ ucp_tag_t tag , mca_pml_base_send_mode_t mode ,
725+ ucp_send_callback_t cb )
721726{
722727 ompi_request_t * req ;
723- ucp_ep_h ep ;
724-
725- PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm ,
726- mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send" );
727-
728- ep = mca_pml_ucx_get_ep (comm , dst );
729- if (OPAL_UNLIKELY (NULL == ep )) {
730- PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
731- return OMPI_ERROR ;
732- }
733728
734729 req = (ompi_request_t * )mca_pml_ucx_common_send (ep , buf , count , datatype ,
735730 mca_pml_ucx_get_datatype (datatype ),
736- PML_UCX_MAKE_SEND_TAG ( tag , comm ) ,
737- mode , mca_pml_ucx_send_completion );
731+ tag , mode ,
732+ mca_pml_ucx_send_completion );
738733
739734 if (OPAL_LIKELY (req == NULL )) {
740735 return OMPI_SUCCESS ;
@@ -749,6 +744,60 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
749744 }
750745}
751746
747+ #if HAVE_DECL_UCP_TAG_SEND_NBR
748+ static inline __opal_attribute_always_inline__ int
749+ mca_pml_ucx_send_nbr (ucp_ep_h ep , const void * buf , size_t count ,
750+ ucp_datatype_t ucx_datatype , ucp_tag_t tag )
751+
752+ {
753+ void * req ;
754+ ucs_status_t status ;
755+
756+ req = PML_UCX_REQ_ALLOCA ();
757+ status = ucp_tag_send_nbr (ep , buf , count , ucx_datatype , tag , req );
758+ if (OPAL_LIKELY (status == UCS_OK )) {
759+ return OMPI_SUCCESS ;
760+ }
761+
762+ ucp_worker_progress (ompi_pml_ucx .ucp_worker );
763+ while ((status = ucp_request_check_status (req )) == UCS_INPROGRESS ) {
764+ opal_progress ();
765+ }
766+
767+ return OPAL_LIKELY (UCS_OK == status ) ? OMPI_SUCCESS : OMPI_ERROR ;
768+ }
769+ #endif
770+
771+ int mca_pml_ucx_send (const void * buf , size_t count , ompi_datatype_t * datatype , int dst ,
772+ int tag , mca_pml_base_send_mode_t mode ,
773+ struct ompi_communicator_t * comm )
774+ {
775+ ucp_ep_h ep ;
776+
777+ PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm ,
778+ mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send" );
779+
780+ ep = mca_pml_ucx_get_ep (comm , dst );
781+ if (OPAL_UNLIKELY (NULL == ep )) {
782+ PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
783+ return OMPI_ERROR ;
784+ }
785+
786+ #if HAVE_DECL_UCP_TAG_SEND_NBR
787+ if (OPAL_LIKELY ((MCA_PML_BASE_SEND_BUFFERED != mode ) &&
788+ (MCA_PML_BASE_SEND_SYNCHRONOUS != mode ))) {
789+ return mca_pml_ucx_send_nbr (ep , buf , count ,
790+ mca_pml_ucx_get_datatype (datatype ),
791+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
792+ }
793+ #endif
794+
795+ return mca_pml_ucx_send_nb (ep , buf , count , datatype ,
796+ mca_pml_ucx_get_datatype (datatype ),
797+ PML_UCX_MAKE_SEND_TAG (tag , comm ), mode ,
798+ mca_pml_ucx_send_completion );
799+ }
800+
752801int mca_pml_ucx_iprobe (int src , int tag , struct ompi_communicator_t * comm ,
753802 int * matched , ompi_status_public_t * mpi_status )
754803{
0 commit comments