3232#include "oshmem/proc/proc.h"
3333#include "oshmem/mca/spml/base/base.h"
3434#include "oshmem/mca/spml/base/spml_base_putreq.h"
35+ #include "oshmem/mca/atomic/atomic.h"
3536#include "oshmem/runtime/runtime.h"
3637
3738#include "oshmem/mca/spml/ucx/spml_ucx_component.h"
@@ -67,6 +68,7 @@ mca_spml_ucx_t mca_spml_ucx = {
6768 .spml_rmkey_free = mca_spml_ucx_rmkey_free ,
6869 .spml_rmkey_ptr = mca_spml_ucx_rmkey_ptr ,
6970 .spml_memuse_hook = mca_spml_ucx_memuse_hook ,
71+ .spml_put_all_nb = mca_spml_ucx_put_all_nb ,
7072 .self = (void * )& mca_spml_ucx
7173 },
7274
@@ -439,8 +441,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
439441 ucx_mkey -> mem_h = (ucp_mem_h )mem_seg -> context ;
440442 }
441443
442- status = ucp_rkey_pack (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h ,
443- & mkeys [0 ].u .data , & len );
444+ status = ucp_rkey_pack (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h ,
445+ & mkeys [0 ].u .data , & len );
444446 if (UCS_OK != status ) {
445447 goto error_unmap ;
446448 }
@@ -477,8 +479,6 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
477479{
478480 spml_ucx_mkey_t * ucx_mkey ;
479481 map_segment_t * mem_seg ;
480- int segno ;
481- int my_pe = oshmem_my_proc_id ();
482482
483483 MCA_SPML_CALL (quiet (oshmem_ctx_default ));
484484 if (!mkeys )
@@ -493,7 +493,7 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
493493 if (OPAL_UNLIKELY (NULL == mem_seg )) {
494494 return OSHMEM_ERROR ;
495495 }
496-
496+
497497 if (MAP_SEGMENT_ALLOC_UCX != mem_seg -> type ) {
498498 ucp_mem_unmap (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h );
499499 }
@@ -545,17 +545,15 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx
545545 opal_atomic_wmb ();
546546}
547547
548- int mca_spml_ucx_ctx_create (long options , shmem_ctx_t * ctx )
548+ static int mca_spml_ucx_ctx_create_common (long options , mca_spml_ucx_ctx_t * * ucx_ctx_p )
549549{
550- mca_spml_ucx_ctx_t * ucx_ctx ;
551550 ucp_worker_params_t params ;
552551 ucp_ep_params_t ep_params ;
553552 size_t i , j , nprocs = oshmem_num_procs ();
554553 ucs_status_t err ;
555- int my_pe = oshmem_my_proc_id ();
556- size_t len ;
557554 spml_ucx_mkey_t * ucx_mkey ;
558555 sshmem_mkey_t * mkey ;
556+ mca_spml_ucx_ctx_t * ucx_ctx ;
559557 int rc = OSHMEM_ERROR ;
560558
561559 ucx_ctx = malloc (sizeof (mca_spml_ucx_ctx_t ));
@@ -580,10 +578,6 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
580578 goto error ;
581579 }
582580
583- if (mca_spml_ucx .active_array .ctxs_count == 0 ) {
584- opal_progress_register (spml_ucx_ctx_progress );
585- }
586-
587581 for (i = 0 ; i < nprocs ; i ++ ) {
588582 ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
589583 ep_params .address = (ucp_address_t * )(mca_spml_ucx .remote_addrs_tbl [i ]);
@@ -609,11 +603,8 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
609603 }
610604 }
611605
612- SHMEM_MUTEX_LOCK (mca_spml_ucx .internal_mutex );
613- _ctx_add (& mca_spml_ucx .active_array , ucx_ctx );
614- SHMEM_MUTEX_UNLOCK (mca_spml_ucx .internal_mutex );
606+ * ucx_ctx_p = ucx_ctx ;
615607
616- (* ctx ) = (shmem_ctx_t )ucx_ctx ;
617608 return OSHMEM_SUCCESS ;
618609
619610 error2 :
@@ -634,6 +625,33 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
634625 return rc ;
635626}
636627
628+ int mca_spml_ucx_ctx_create (long options , shmem_ctx_t * ctx )
629+ {
630+ mca_spml_ucx_ctx_t * ucx_ctx ;
631+ int rc ;
632+
633+ /* Take a lock controlling context creation. AUX context may set specific
634+ * UCX parameters affecting worker creation, which are not needed for
635+ * regular contexts. */
636+ pthread_mutex_lock (& mca_spml_ucx .ctx_create_mutex );
637+ rc = mca_spml_ucx_ctx_create_common (options , & ucx_ctx );
638+ pthread_mutex_unlock (& mca_spml_ucx .ctx_create_mutex );
639+ if (rc != OSHMEM_SUCCESS ) {
640+ return rc ;
641+ }
642+
643+ if (mca_spml_ucx .active_array .ctxs_count == 0 ) {
644+ opal_progress_register (spml_ucx_ctx_progress );
645+ }
646+
647+ SHMEM_MUTEX_LOCK (mca_spml_ucx .internal_mutex );
648+ _ctx_add (& mca_spml_ucx .active_array , ucx_ctx );
649+ SHMEM_MUTEX_UNLOCK (mca_spml_ucx .internal_mutex );
650+
651+ (* ctx ) = (shmem_ctx_t )ucx_ctx ;
652+ return OSHMEM_SUCCESS ;
653+ }
654+
637655void mca_spml_ucx_ctx_destroy (shmem_ctx_t ctx )
638656{
639657 MCA_SPML_CALL (quiet (ctx ));
@@ -748,6 +766,15 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx)
748766 oshmem_shmem_abort (-1 );
749767 return ret ;
750768 }
769+
770+ /* If put_all_nb op/s is/are being executed asynchronously, need to wait its
771+ * completion as well. */
772+ if (ctx == oshmem_ctx_default ) {
773+ while (mca_spml_ucx .aux_refcnt ) {
774+ opal_progress ();
775+ }
776+ }
777+
751778 return OSHMEM_SUCCESS ;
752779}
753780
@@ -785,3 +812,99 @@ int mca_spml_ucx_send(void* buf,
785812
786813 return rc ;
787814}
815+
816+ /* this can be called with request==NULL in case of immediate completion */
817+ static void mca_spml_ucx_put_all_complete_cb (void * request , ucs_status_t status )
818+ {
819+ if (mca_spml_ucx .async_progress && (-- mca_spml_ucx .aux_refcnt == 0 )) {
820+ opal_event_evtimer_del (mca_spml_ucx .tick_event );
821+ opal_progress_unregister (spml_ucx_progress_aux_ctx );
822+ }
823+
824+ if (request != NULL ) {
825+ ucp_request_free (request );
826+ }
827+ }
828+
829+ /* Should be called with AUX lock taken */
830+ static int mca_spml_ucx_create_aux_ctx (void )
831+ {
832+ unsigned major = 0 ;
833+ unsigned minor = 0 ;
834+ unsigned rel_number = 0 ;
835+ int rc ;
836+ bool rand_dci_supp ;
837+
838+ ucp_get_version (& major , & minor , & rel_number );
839+ rand_dci_supp = UCX_VERSION (major , minor , rel_number ) >= UCX_VERSION (1 , 6 , 0 );
840+
841+ if (rand_dci_supp ) {
842+ pthread_mutex_lock (& mca_spml_ucx .ctx_create_mutex );
843+ opal_setenv ("UCX_DC_MLX5_TX_POLICY" , "rand" , 0 , & environ );
844+ }
845+
846+ rc = mca_spml_ucx_ctx_create_common (SHMEM_CTX_PRIVATE , & mca_spml_ucx .aux_ctx );
847+
848+ if (rand_dci_supp ) {
849+ opal_unsetenv ("UCX_DC_MLX5_TX_POLICY" , & environ );
850+ pthread_mutex_unlock (& mca_spml_ucx .ctx_create_mutex );
851+ }
852+
853+ return rc ;
854+ }
855+
856+ int mca_spml_ucx_put_all_nb (void * dest , const void * source , size_t size , long * counter )
857+ {
858+ int my_pe = oshmem_my_proc_id ();
859+ long val = 1 ;
860+ int peer , dst_pe , rc ;
861+ shmem_ctx_t ctx ;
862+ struct timeval tv ;
863+ void * request ;
864+
865+ mca_spml_ucx_aux_lock ();
866+ if (mca_spml_ucx .async_progress ) {
867+ if (mca_spml_ucx .aux_ctx == NULL ) {
868+ rc = mca_spml_ucx_create_aux_ctx ();
869+ if (rc != OMPI_SUCCESS ) {
870+ mca_spml_ucx_aux_unlock ();
871+ oshmem_shmem_abort (-1 );
872+ }
873+ }
874+
875+ if (mca_spml_ucx .aux_refcnt ++ == 0 ) {
876+ tv .tv_sec = 0 ;
877+ tv .tv_usec = mca_spml_ucx .async_tick ;
878+ opal_event_evtimer_add (mca_spml_ucx .tick_event , & tv );
879+ opal_progress_register (spml_ucx_progress_aux_ctx );
880+ }
881+ ctx = (shmem_ctx_t )mca_spml_ucx .aux_ctx ;
882+ } else {
883+ ctx = oshmem_ctx_default ;
884+ }
885+
886+ for (peer = 0 ; peer < oshmem_num_procs (); peer ++ ) {
887+ dst_pe = (peer + my_pe ) % oshmem_group_all -> proc_count ;
888+ rc = mca_spml_ucx_put_nb (ctx ,
889+ (void * )((uintptr_t )dest + my_pe * size ),
890+ size ,
891+ (void * )((uintptr_t )source + dst_pe * size ),
892+ dst_pe , NULL );
893+ RUNTIME_CHECK_RC (rc );
894+
895+ mca_spml_ucx_fence (ctx );
896+
897+ rc = MCA_ATOMIC_CALL (add (ctx , (void * )counter , val , sizeof (val ), dst_pe ));
898+ RUNTIME_CHECK_RC (rc );
899+ }
900+
901+ request = ucp_worker_flush_nb (((mca_spml_ucx_ctx_t * )ctx )-> ucp_worker , 0 ,
902+ mca_spml_ucx_put_all_complete_cb );
903+ if (!UCS_PTR_IS_PTR (request )) {
904+ mca_spml_ucx_put_all_complete_cb (NULL , UCS_PTR_STATUS (request ));
905+ }
906+
907+ mca_spml_ucx_aux_unlock ();
908+
909+ return OSHMEM_SUCCESS ;
910+ }
0 commit comments