3232#include "oshmem/proc/proc.h"
3333#include "oshmem/mca/spml/base/base.h"
3434#include "oshmem/mca/spml/base/spml_base_putreq.h"
35+ #include "oshmem/mca/atomic/atomic.h"
3536#include "oshmem/runtime/runtime.h"
3637
3738#include "oshmem/mca/spml/ucx/spml_ucx_component.h"
@@ -67,6 +68,7 @@ mca_spml_ucx_t mca_spml_ucx = {
6768 .spml_rmkey_free = mca_spml_ucx_rmkey_free ,
6869 .spml_rmkey_ptr = mca_spml_ucx_rmkey_ptr ,
6970 .spml_memuse_hook = mca_spml_ucx_memuse_hook ,
71+ .spml_put_all_nb = mca_spml_ucx_put_all_nb ,
7072 .self = (void * )& mca_spml_ucx
7173 },
7274
@@ -439,8 +441,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
439441 ucx_mkey -> mem_h = (ucp_mem_h )mem_seg -> context ;
440442 }
441443
442- status = ucp_rkey_pack (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h ,
443- & mkeys [0 ].u .data , & len );
444+ status = ucp_rkey_pack (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h ,
445+ & mkeys [0 ].u .data , & len );
444446 if (UCS_OK != status ) {
445447 goto error_unmap ;
446448 }
@@ -477,8 +479,6 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
477479{
478480 spml_ucx_mkey_t * ucx_mkey ;
479481 map_segment_t * mem_seg ;
480- int segno ;
481- int my_pe = oshmem_my_proc_id ();
482482
483483 MCA_SPML_CALL (quiet (oshmem_ctx_default ));
484484 if (!mkeys )
@@ -493,7 +493,7 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
493493 if (OPAL_UNLIKELY (NULL == mem_seg )) {
494494 return OSHMEM_ERROR ;
495495 }
496-
496+
497497 if (MAP_SEGMENT_ALLOC_UCX != mem_seg -> type ) {
498498 ucp_mem_unmap (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h );
499499 }
@@ -545,17 +545,15 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx
545545 opal_atomic_wmb ();
546546}
547547
548- int mca_spml_ucx_ctx_create (long options , shmem_ctx_t * ctx )
548+ static int mca_spml_ucx_ctx_create_common (long options , mca_spml_ucx_ctx_t * * ucx_ctx_p )
549549{
550- mca_spml_ucx_ctx_t * ucx_ctx ;
551550 ucp_worker_params_t params ;
552551 ucp_ep_params_t ep_params ;
553552 size_t i , j , nprocs = oshmem_num_procs ();
554553 ucs_status_t err ;
555- int my_pe = oshmem_my_proc_id ();
556- size_t len ;
557554 spml_ucx_mkey_t * ucx_mkey ;
558555 sshmem_mkey_t * mkey ;
556+ mca_spml_ucx_ctx_t * ucx_ctx ;
559557 int rc = OSHMEM_ERROR ;
560558
561559 ucx_ctx = malloc (sizeof (mca_spml_ucx_ctx_t ));
@@ -580,10 +578,6 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
580578 goto error ;
581579 }
582580
583- if (mca_spml_ucx .active_array .ctxs_count == 0 ) {
584- opal_progress_register (spml_ucx_ctx_progress );
585- }
586-
587581 for (i = 0 ; i < nprocs ; i ++ ) {
588582 ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
589583 ep_params .address = (ucp_address_t * )(mca_spml_ucx .remote_addrs_tbl [i ]);
@@ -609,11 +603,8 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
609603 }
610604 }
611605
612- SHMEM_MUTEX_LOCK (mca_spml_ucx .internal_mutex );
613- _ctx_add (& mca_spml_ucx .active_array , ucx_ctx );
614- SHMEM_MUTEX_UNLOCK (mca_spml_ucx .internal_mutex );
606+ * ucx_ctx_p = ucx_ctx ;
615607
616- (* ctx ) = (shmem_ctx_t )ucx_ctx ;
617608 return OSHMEM_SUCCESS ;
618609
619610 error2 :
@@ -634,6 +625,33 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
634625 return rc ;
635626}
636627
628+ int mca_spml_ucx_ctx_create (long options , shmem_ctx_t * ctx )
629+ {
630+ mca_spml_ucx_ctx_t * ucx_ctx ;
631+ int rc ;
632+
633+ /* Take a lock controlling aux context. AUX context may set specific
634+ * UCX parameters affecting worker creation, which are not needed for
635+ * regular contexts. */
636+ mca_spml_ucx_aux_lock ();
637+ rc = mca_spml_ucx_ctx_create_common (options , & ucx_ctx );
638+ mca_spml_ucx_aux_unlock ();
639+ if (rc != OSHMEM_SUCCESS ) {
640+ return rc ;
641+ }
642+
643+ if (mca_spml_ucx .active_array .ctxs_count == 0 ) {
644+ opal_progress_register (spml_ucx_ctx_progress );
645+ }
646+
647+ SHMEM_MUTEX_LOCK (mca_spml_ucx .internal_mutex );
648+ _ctx_add (& mca_spml_ucx .active_array , ucx_ctx );
649+ SHMEM_MUTEX_UNLOCK (mca_spml_ucx .internal_mutex );
650+
651+ (* ctx ) = (shmem_ctx_t )ucx_ctx ;
652+ return OSHMEM_SUCCESS ;
653+ }
654+
637655void mca_spml_ucx_ctx_destroy (shmem_ctx_t ctx )
638656{
639657 MCA_SPML_CALL (quiet (ctx ));
@@ -748,6 +766,15 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx)
748766 oshmem_shmem_abort (-1 );
749767 return ret ;
750768 }
769+
770+ /* If put_all_nb op/s is/are being executed asynchronously, need to wait its
771+ * completion as well. */
772+ if (ctx == oshmem_ctx_default ) {
773+ while (mca_spml_ucx .aux_refcnt ) {
774+ opal_progress ();
775+ }
776+ }
777+
751778 return OSHMEM_SUCCESS ;
752779}
753780
@@ -785,3 +812,99 @@ int mca_spml_ucx_send(void* buf,
785812
786813 return rc ;
787814}
815+
816+ static void mca_spml_ucx_put_all_complete_cb (void * request , ucs_status_t status )
817+ {
818+ if (mca_spml_ucx .async_progress && (-- mca_spml_ucx .aux_refcnt == 0 )) {
819+ opal_event_evtimer_del (mca_spml_ucx .tick_event );
820+ opal_progress_unregister (spml_ucx_progress_aux_ctx );
821+ }
822+
823+ if (request != NULL ) {
824+ ucp_request_free (request );
825+ }
826+ }
827+
828+ /* Should be called with AUX lock taken */
829+ static int mca_spml_ucx_create_aux_ctx (void )
830+ {
831+ unsigned major = 0 ;
832+ unsigned minor = 0 ;
833+ unsigned rel_number = 0 ;
834+ int rc ;
835+ bool rand_dci_supp ;
836+
837+ ucp_get_version (& major , & minor , & rel_number );
838+ rand_dci_supp = UCX_VERSION (major , minor , rel_number ) >= UCX_VERSION (1 , 6 , 0 );
839+
840+ if (rand_dci_supp ) {
841+ opal_setenv ("UCX_DC_TX_POLICY" , "rand" , 1 , & environ );
842+ opal_setenv ("UCX_DC_MLX5_TX_POLICY" , "rand" , 1 , & environ );
843+ }
844+
845+ rc = mca_spml_ucx_ctx_create_common (SHMEM_CTX_PRIVATE , & mca_spml_ucx .aux_ctx );
846+
847+ if (rand_dci_supp ) {
848+ opal_unsetenv ("UCX_DC_TX_POLICY" , & environ );
849+ opal_unsetenv ("UCX_DC_MLX5_TX_POLICY" , & environ );
850+ }
851+
852+ return rc ;
853+ }
854+
855+ int mca_spml_ucx_put_all_nb (void * dest , const void * source , size_t size , long * counter )
856+ {
857+ int my_pe = oshmem_my_proc_id ();
858+ long val = 1 ;
859+ int peer , dst_pe , rc ;
860+ shmem_ctx_t ctx ;
861+ struct timeval tv ;
862+ void * request ;
863+
864+ mca_spml_ucx_aux_lock ();
865+ if (mca_spml_ucx .async_progress ) {
866+ if (mca_spml_ucx .aux_ctx == NULL ) {
867+ rc = mca_spml_ucx_create_aux_ctx ();
868+ if (rc != OMPI_SUCCESS ) {
869+ mca_spml_ucx_aux_unlock ();
870+ oshmem_shmem_abort (-1 );
871+ }
872+ }
873+
874+ if (!mca_spml_ucx .aux_refcnt ) {
875+ tv .tv_sec = 0 ;
876+ tv .tv_usec = mca_spml_ucx .async_tick ;
877+ opal_event_evtimer_add (mca_spml_ucx .tick_event , & tv );
878+ opal_progress_register (spml_ucx_progress_aux_ctx );
879+ }
880+ ctx = (shmem_ctx_t )mca_spml_ucx .aux_ctx ;
881+ ++ mca_spml_ucx .aux_refcnt ;
882+ } else {
883+ ctx = oshmem_ctx_default ;
884+ }
885+
886+ for (peer = 0 ; peer < oshmem_num_procs (); peer ++ ) {
887+ dst_pe = (peer + my_pe ) % oshmem_group_all -> proc_count ;
888+ rc = mca_spml_ucx_put_nb (ctx ,
889+ (void * )((uintptr_t )dest + my_pe * size ),
890+ size ,
891+ (void * )((uintptr_t )source + dst_pe * size ),
892+ dst_pe , NULL );
893+ RUNTIME_CHECK_RC (rc );
894+
895+ mca_spml_ucx_fence (ctx );
896+
897+ rc = MCA_ATOMIC_CALL (add (ctx , (void * )counter , val , sizeof (val ), dst_pe ));
898+ RUNTIME_CHECK_RC (rc );
899+ }
900+
901+ request = ucp_worker_flush_nb (((mca_spml_ucx_ctx_t * )ctx )-> ucp_worker , 0 ,
902+ mca_spml_ucx_put_all_complete_cb );
903+ if (!UCS_PTR_IS_PTR (request )) {
904+ mca_spml_ucx_put_all_complete_cb (NULL , UCS_PTR_STATUS (request ));
905+ }
906+
907+ mca_spml_ucx_aux_unlock ();
908+
909+ return OSHMEM_SUCCESS ;
910+ }
0 commit comments