@@ -257,15 +257,30 @@ static inline int start_atomicity(ompi_osc_ucx_module_t *module, int target) {
257
257
}
258
258
}
259
259
260
- static inline int end_atomicity (ompi_osc_ucx_module_t * module , int target ) {
260
+ static inline int end_atomicity (
261
+ ompi_osc_ucx_module_t * module ,
262
+ int target ,
263
+ void * free_ptr ) {
261
264
uint64_t result_value = 0 ;
262
265
uint64_t remote_addr = (module -> state_addrs )[target ] + OSC_UCX_STATE_ACC_LOCK_OFFSET ;
263
266
int ret = OMPI_SUCCESS ;
264
267
268
+ /* fence any still active operations */
269
+ ret = opal_common_ucx_wpmem_fence (module -> mem );
270
+ if (ret != OMPI_SUCCESS ) {
271
+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
272
+ return OMPI_ERROR ;
273
+ }
274
+
265
275
ret = opal_common_ucx_wpmem_fetch (module -> state_mem ,
266
276
UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
267
277
target , & result_value , sizeof (result_value ),
268
278
remote_addr );
279
+
280
+ /* TODO: encapsulate in a request and make the release non-blocking */
281
+ if (NULL != free_ptr ) {
282
+ free (free_ptr );
283
+ }
269
284
if (ret != OMPI_SUCCESS ) {
270
285
OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fetch failed: %d" , ret );
271
286
return OMPI_ERROR ;
@@ -546,6 +561,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
546
561
547
562
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
548
563
int ret = OMPI_SUCCESS ;
564
+ void * free_ptr = NULL ;
549
565
550
566
ret = check_sync_state (module , target , false);
551
567
if (ret != OMPI_SUCCESS ) {
@@ -576,7 +592,6 @@ int accumulate_req(const void *origin_addr, int origin_count,
576
592
return ret ;
577
593
}
578
594
} else {
579
- void * temp_addr_holder = NULL ;
580
595
void * temp_addr = NULL ;
581
596
uint32_t temp_count ;
582
597
ompi_datatype_t * temp_dt ;
@@ -593,7 +608,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
593
608
}
594
609
}
595
610
ompi_datatype_get_true_extent (temp_dt , & temp_lb , & temp_extent );
596
- temp_addr = temp_addr_holder = malloc (temp_extent * temp_count );
611
+ temp_addr = free_ptr = malloc (temp_extent * temp_count );
597
612
if (temp_addr == NULL ) {
598
613
return OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
599
614
}
@@ -659,20 +674,14 @@ int accumulate_req(const void *origin_addr, int origin_count,
659
674
return ret ;
660
675
}
661
676
662
- ret = opal_common_ucx_wpmem_flush (module -> mem , OPAL_COMMON_UCX_SCOPE_EP , target );
663
- if (ret != OMPI_SUCCESS ) {
664
- return ret ;
665
- }
666
-
667
- free (temp_addr_holder );
668
677
}
669
678
670
679
if (NULL != ucx_req ) {
671
680
// nothing to wait for, mark request as completed
672
681
ompi_request_complete (& ucx_req -> super , true);
673
682
}
674
683
675
- return end_atomicity (module , target );
684
+ return end_atomicity (module , target , free_ptr );
676
685
}
677
686
678
687
int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
@@ -729,14 +738,7 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
729
738
return ret ;
730
739
}
731
740
732
- // fence before releasing the accumulate lock
733
- ret = opal_common_ucx_wpmem_fence (module -> mem );
734
- if (ret != OMPI_SUCCESS ) {
735
- OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
736
- // don't return error, try to release the accumulate lock
737
- }
738
-
739
- return end_atomicity (module , target );
741
+ return end_atomicity (module , target , NULL );
740
742
}
741
743
742
744
int ompi_osc_ucx_fetch_and_op (const void * origin_addr , void * result_addr ,
@@ -790,7 +792,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
790
792
return ret ;
791
793
}
792
794
793
- return end_atomicity (module , target );
795
+ return end_atomicity (module , target , NULL );
794
796
} else {
795
797
return ompi_osc_ucx_get_accumulate (origin_addr , 1 , dt , result_addr , 1 , dt ,
796
798
target , target_disp , 1 , dt , op , win );
@@ -808,6 +810,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
808
810
ompi_osc_ucx_request_t * ucx_req ) {
809
811
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
810
812
int ret = OMPI_SUCCESS ;
813
+ void * free_addr = NULL ;
811
814
812
815
ret = check_sync_state (module , target , false);
813
816
if (ret != OMPI_SUCCESS ) {
@@ -841,7 +844,6 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
841
844
return ret ;
842
845
}
843
846
} else {
844
- void * temp_addr_holder = NULL ;
845
847
void * temp_addr = NULL ;
846
848
uint32_t temp_count ;
847
849
ompi_datatype_t * temp_dt ;
@@ -858,7 +860,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
858
860
}
859
861
}
860
862
ompi_datatype_get_true_extent (temp_dt , & temp_lb , & temp_extent );
861
- temp_addr = temp_addr_holder = malloc (temp_extent * temp_count );
863
+ temp_addr = free_addr = malloc (temp_extent * temp_count );
862
864
if (temp_addr == NULL ) {
863
865
return OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
864
866
}
@@ -922,13 +924,6 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
922
924
if (ret != OMPI_SUCCESS ) {
923
925
return ret ;
924
926
}
925
-
926
- ret = opal_common_ucx_wpmem_flush (module -> mem , OPAL_COMMON_UCX_SCOPE_EP , target );
927
- if (ret != OMPI_SUCCESS ) {
928
- return ret ;
929
- }
930
-
931
- free (temp_addr_holder );
932
927
}
933
928
}
934
929
@@ -938,7 +933,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
938
933
}
939
934
940
935
941
- return end_atomicity (module , target );
936
+ return end_atomicity (module , target , free_addr );
942
937
}
943
938
944
939
int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
0 commit comments