@@ -722,6 +722,36 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
722
722
target_disp , target_count , target_dt , op , win , NULL );
723
723
}
724
724
725
+ static int
726
+ do_atomic_compare_and_swap (const void * origin_addr , const void * compare_addr ,
727
+ void * result_addr , struct ompi_datatype_t * dt ,
728
+ int target , uint64_t remote_addr ,
729
+ ompi_osc_ucx_module_t * module )
730
+ {
731
+ int ret ;
732
+ bool lock_acquired = false;
733
+ size_t dt_bytes ;
734
+ if (!module -> acc_single_intrinsic ) {
735
+ ret = start_atomicity (module , target , & lock_acquired );
736
+ if (ret != OMPI_SUCCESS ) {
737
+ return ret ;
738
+ }
739
+ }
740
+
741
+ ompi_datatype_type_size (dt , & dt_bytes );
742
+ uint64_t compare_val = opal_common_ucx_load_uint64 (compare_addr , dt_bytes );
743
+ uint64_t value = opal_common_ucx_load_uint64 (origin_addr , dt_bytes );
744
+ ret = opal_common_ucx_wpmem_cmpswp_nb (module -> mem , compare_val , value , target ,
745
+ result_addr , dt_bytes , remote_addr ,
746
+ NULL , NULL );
747
+
748
+ if (module -> acc_single_intrinsic ) {
749
+ return ret ;
750
+ }
751
+
752
+ return end_atomicity (module , target , lock_acquired , NULL );
753
+ }
754
+
725
755
int ompi_osc_ucx_compare_and_swap (const void * origin_addr , const void * compare_addr ,
726
756
void * result_addr , struct ompi_datatype_t * dt ,
727
757
int target , ptrdiff_t target_disp ,
@@ -732,40 +762,55 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
732
762
int ret = OMPI_SUCCESS ;
733
763
bool lock_acquired = false;
734
764
735
- ompi_datatype_type_size (dt , & dt_bytes );
736
- if (sizeof (uint64_t ) < dt_bytes ) {
737
- return OMPI_ERR_NOT_SUPPORTED ;
738
- }
739
-
740
765
ret = check_sync_state (module , target , false);
741
766
if (ret != OMPI_SUCCESS ) {
742
767
return ret ;
743
768
}
744
769
745
- if (!module -> acc_single_intrinsic ) {
746
- ret = start_atomicity (module , target , & lock_acquired );
747
- if (ret != OMPI_SUCCESS ) {
748
- return ret ;
749
- }
750
- }
751
-
752
770
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
753
771
ret = get_dynamic_win_info (remote_addr , module , target );
754
772
if (ret != OMPI_SUCCESS ) {
755
773
return ret ;
756
774
}
757
775
}
758
776
759
- uint64_t compare_val = opal_common_ucx_load_uint64 (compare_addr , dt_bytes );
760
- uint64_t value = opal_common_ucx_load_uint64 (origin_addr , dt_bytes );
761
- ret = opal_common_ucx_wpmem_cmpswp_nb (module -> mem , compare_val , value , target ,
762
- result_addr , dt_bytes , remote_addr ,
763
- NULL , NULL );
777
+ ompi_datatype_type_size (dt , & dt_bytes );
778
+ if (4 == dt_bytes || 8 == dt_bytes ) {
779
+ // fast path using UCX atomic operations
780
+ return do_atomic_compare_and_swap (origin_addr , compare_addr ,
781
+ result_addr , dt , target ,
782
+ remote_addr , module );
783
+ }
764
784
765
- if (module -> acc_single_intrinsic ) {
785
+ /* fall back to get-compare-put */
786
+
787
+ ret = start_atomicity (module , target , & lock_acquired );
788
+ if (ret != OMPI_SUCCESS ) {
766
789
return ret ;
767
790
}
768
791
792
+ ret = opal_common_ucx_wpmem_putget (module -> mem , OPAL_COMMON_UCX_GET , target ,
793
+ & result_addr , dt_bytes , remote_addr );
794
+ if (OPAL_SUCCESS != ret ) {
795
+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_putget failed: %d" , ret );
796
+ return OMPI_ERROR ;
797
+ }
798
+
799
+ ret = opal_common_ucx_wpmem_flush (module -> mem , OPAL_COMMON_UCX_SCOPE_EP , target );
800
+ if (ret != OPAL_SUCCESS ) {
801
+ return ret ;
802
+ }
803
+
804
+ if (0 == memcmp (result_addr , compare_addr , dt_bytes )) {
805
+ // write the new value
806
+ ret = opal_common_ucx_wpmem_putget (module -> mem , OPAL_COMMON_UCX_PUT , target ,
807
+ (void * )origin_addr , dt_bytes , remote_addr );
808
+ if (OPAL_SUCCESS != ret ) {
809
+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_putget failed: %d" , ret );
810
+ return OMPI_ERROR ;
811
+ }
812
+ }
813
+
769
814
return end_atomicity (module , target , lock_acquired , NULL );
770
815
}
771
816
0 commit comments