15
15
#include "opal/runtime/opal.h"
16
16
#include "opal/mca/pmix/pmix.h"
17
17
#include "ompi/message/message.h"
18
+ #include "ompi/mca/pml/base/pml_base_bsend.h"
18
19
#include "pml_ucx_request.h"
19
20
20
21
#include <inttypes.h>
@@ -333,7 +334,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
333
334
ucs_status_t status ;
334
335
size_t i ;
335
336
336
- PML_UCX_VERBOSE (2 , "waiting for %d disconnect requests" , * count_p );
337
+ PML_UCX_VERBOSE (2 , "waiting for %d disconnect requests" , ( int ) * count_p );
337
338
for (i = 0 ; i < * count_p ; ++ i ) {
338
339
do {
339
340
opal_progress ();
@@ -343,7 +344,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
343
344
PML_UCX_ERROR ("disconnect request failed: %s" ,
344
345
ucs_status_string (status ));
345
346
}
346
- ucp_request_release (reqs [i ]);
347
+ ucp_request_free (reqs [i ]);
347
348
reqs [i ] = NULL ;
348
349
}
349
350
@@ -391,7 +392,7 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
391
392
392
393
proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ] = NULL ;
393
394
394
- if (num_reqs >= ompi_pml_ucx .num_disconnect ) {
395
+ if (( int ) num_reqs >= ompi_pml_ucx .num_disconnect ) {
395
396
mca_pml_ucx_waitall (dreqs , & num_reqs );
396
397
}
397
398
}
@@ -494,7 +495,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
494
495
PML_UCX_TRACE_RECV ("%s" , buf , count , datatype , src , tag , comm , "recv" );
495
496
496
497
PML_UCX_MAKE_RECV_TAG (ucp_tag , ucp_tag_mask , tag , src , comm );
497
- req = alloca (ompi_pml_ucx .request_size ) + ompi_pml_ucx .request_size ;
498
+ req = ( char * ) alloca (ompi_pml_ucx .request_size ) + ompi_pml_ucx .request_size ;
498
499
status = ucp_tag_recv_nbr (ompi_pml_ucx .ucp_worker , buf , count ,
499
500
mca_pml_ucx_get_datatype (datatype ),
500
501
ucp_tag , ucp_tag_mask , req );
@@ -556,15 +557,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
556
557
req -> flags = MCA_PML_UCX_REQUEST_FLAG_SEND ;
557
558
req -> buffer = (void * )buf ;
558
559
req -> count = count ;
559
- req -> datatype = mca_pml_ucx_get_datatype (datatype );
560
560
req -> tag = PML_UCX_MAKE_SEND_TAG (tag , comm );
561
561
req -> send .mode = mode ;
562
562
req -> send .ep = ep ;
563
+ if (MCA_PML_BASE_SEND_BUFFERED == mode ) {
564
+ req -> ompi_datatype = datatype ;
565
+ OBJ_RETAIN (datatype );
566
+ } else {
567
+ req -> datatype = mca_pml_ucx_get_datatype (datatype );
568
+ }
563
569
564
570
* request = & req -> ompi ;
565
571
return OMPI_SUCCESS ;
566
572
}
567
573
574
+ static int
575
+ mca_pml_ucx_bsend (ucp_ep_h ep , const void * buf , size_t count ,
576
+ ompi_datatype_t * datatype , uint64_t pml_tag )
577
+ {
578
+ ompi_request_t * req ;
579
+ void * packed_data ;
580
+ size_t packed_length ;
581
+ size_t offset ;
582
+ uint32_t iov_count ;
583
+ struct iovec iov ;
584
+ opal_convertor_t opal_conv ;
585
+
586
+ OBJ_CONSTRUCT (& opal_conv , opal_convertor_t );
587
+ opal_convertor_copy_and_prepare_for_recv (ompi_proc_local_proc -> super .proc_convertor ,
588
+ & datatype -> super , count , buf , 0 ,
589
+ & opal_conv );
590
+ opal_convertor_get_packed_size (& opal_conv , & packed_length );
591
+
592
+ packed_data = mca_pml_base_bsend_request_alloc_buf (packed_length );
593
+ if (OPAL_UNLIKELY (NULL == packed_data )) {
594
+ OBJ_DESTRUCT (& opal_conv );
595
+ PML_UCX_ERROR ("bsend: failed to allocate buffer" );
596
+ return OMPI_ERR_OUT_OF_RESOURCE ;
597
+ }
598
+
599
+ iov_count = 1 ;
600
+ iov .iov_base = packed_data ;
601
+ iov .iov_len = packed_length ;
602
+
603
+ PML_UCX_VERBOSE (8 , "bsend of packed buffer %p len %d\n" , packed_data , packed_length );
604
+ offset = 0 ;
605
+ opal_convertor_set_position (& opal_conv , & offset );
606
+ if (0 > opal_convertor_pack (& opal_conv , & iov , & iov_count , & packed_length )) {
607
+ mca_pml_base_bsend_request_free (packed_data );
608
+ OBJ_DESTRUCT (& opal_conv );
609
+ PML_UCX_ERROR ("bsend: failed to pack user datatype" );
610
+ return OMPI_ERROR ;
611
+ }
612
+
613
+ OBJ_DESTRUCT (& opal_conv );
614
+
615
+ req = (ompi_request_t * )ucp_tag_send_nb (ep , packed_data , packed_length ,
616
+ ucp_dt_make_contig (1 ), pml_tag ,
617
+ mca_pml_ucx_bsend_completion );
618
+ if (NULL == req ) {
619
+ /* request was completed in place */
620
+ mca_pml_base_bsend_request_free (packed_data );
621
+ return OMPI_SUCCESS ;
622
+ }
623
+
624
+ if (OPAL_UNLIKELY (UCS_PTR_IS_ERR (req ))) {
625
+ mca_pml_base_bsend_request_free (packed_data );
626
+ PML_UCX_ERROR ("ucx bsend failed: %s" , ucs_status_string (UCS_PTR_STATUS (req )));
627
+ return OMPI_ERROR ;
628
+ }
629
+
630
+ req -> req_complete_cb_data = packed_data ;
631
+ return OMPI_SUCCESS ;
632
+ }
633
+
568
634
int mca_pml_ucx_isend (const void * buf , size_t count , ompi_datatype_t * datatype ,
569
635
int dst , int tag , mca_pml_base_send_mode_t mode ,
570
636
struct ompi_communicator_t * comm ,
@@ -573,8 +639,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
573
639
ompi_request_t * req ;
574
640
ucp_ep_h ep ;
575
641
576
- PML_UCX_TRACE_SEND ("isend request *%p" , buf , count , datatype , dst , tag , mode ,
577
- comm , (void * )request )
642
+ PML_UCX_TRACE_SEND ("i%ssend request *%p" ,
643
+ buf , count , datatype , dst , tag , mode , comm ,
644
+ mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "" ,
645
+ (void * )request )
578
646
579
647
/* TODO special care to sync/buffered send */
580
648
@@ -584,6 +652,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
584
652
return OMPI_ERROR ;
585
653
}
586
654
655
+ /* Special care to sync/buffered send */
656
+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
657
+ * request = & ompi_pml_ucx .completed_send_req ;
658
+ return mca_pml_ucx_bsend (ep , buf , count , datatype ,
659
+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
660
+ }
661
+
587
662
req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
588
663
mca_pml_ucx_get_datatype (datatype ),
589
664
PML_UCX_MAKE_SEND_TAG (tag , comm ),
@@ -609,16 +684,21 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
609
684
ompi_request_t * req ;
610
685
ucp_ep_h ep ;
611
686
612
- PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm , "send" );
613
-
614
- /* TODO special care to sync/buffered send */
687
+ PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm ,
688
+ mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send" );
615
689
616
690
ep = mca_pml_ucx_get_ep (comm , dst );
617
691
if (OPAL_UNLIKELY (NULL == ep )) {
618
692
PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
619
693
return OMPI_ERROR ;
620
694
}
621
695
696
+ /* Special care to sync/buffered send */
697
+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
698
+ return mca_pml_ucx_bsend (ep , buf , count , datatype ,
699
+ PML_UCX_MAKE_SEND_TAG (tag , comm ));
700
+ }
701
+
622
702
req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
623
703
mca_pml_ucx_get_datatype (datatype ),
624
704
PML_UCX_MAKE_SEND_TAG (tag , comm ),
@@ -781,6 +861,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
781
861
mca_pml_ucx_persistent_request_t * preq ;
782
862
ompi_request_t * tmp_req ;
783
863
size_t i ;
864
+ int rc ;
784
865
785
866
for (i = 0 ; i < count ; ++ i ) {
786
867
preq = (mca_pml_ucx_persistent_request_t * )requests [i ];
@@ -795,12 +876,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
795
876
mca_pml_ucx_request_reset (& preq -> ompi );
796
877
797
878
if (preq -> flags & MCA_PML_UCX_REQUEST_FLAG_SEND ) {
798
- /* TODO special care to sync/buffered send */
799
- PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
800
- tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
801
- preq -> count , preq -> datatype ,
802
- preq -> tag ,
803
- mca_pml_ucx_psend_completion );
879
+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == preq -> send .mode )) {
880
+ PML_UCX_VERBOSE (8 , "start bsend request %p" , (void * )preq );
881
+ rc = mca_pml_ucx_bsend (preq -> send .ep , preq -> buffer , preq -> count ,
882
+ preq -> ompi_datatype , preq -> tag );
883
+ if (OMPI_SUCCESS != rc ) {
884
+ return rc ;
885
+ }
886
+ /* pretend that we got immediate completion */
887
+ tmp_req = NULL ;
888
+ } else {
889
+ PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
890
+ tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
891
+ preq -> count , preq -> datatype ,
892
+ preq -> tag ,
893
+ mca_pml_ucx_psend_completion );
894
+ }
804
895
} else {
805
896
PML_UCX_VERBOSE (8 , "start recv request %p" , (void * )preq );
806
897
tmp_req = (ompi_request_t * )ucp_tag_recv_nb (ompi_pml_ucx .ucp_worker ,
0 commit comments