@@ -235,48 +235,73 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
235
235
return ret ;
236
236
}
237
237
238
- static inline int start_atomicity (ompi_osc_ucx_module_t * module , int target ) {
238
+ static inline bool need_acc_lock (ompi_osc_ucx_module_t * module , int target )
239
+ {
240
+ ompi_osc_ucx_lock_t * lock = NULL ;
241
+ opal_hash_table_get_value_uint32 (& module -> outstanding_locks ,
242
+ (uint32_t ) target , (void * * ) & lock );
243
+
244
+ /* if there is an exclusive lock there is no need to acqurie the accumulate lock */
245
+ return !(NULL != lock && LOCK_EXCLUSIVE == lock -> type );
246
+ }
247
+
248
+ static inline int start_atomicity (
249
+ ompi_osc_ucx_module_t * module ,
250
+ int target ,
251
+ bool * lock_acquired ) {
239
252
uint64_t result_value = -1 ;
240
253
uint64_t remote_addr = (module -> state_addrs )[target ] + OSC_UCX_STATE_ACC_LOCK_OFFSET ;
241
254
int ret = OMPI_SUCCESS ;
242
255
243
- for (;;) {
244
- ret = opal_common_ucx_wpmem_cmpswp (module -> state_mem ,
245
- TARGET_LOCK_UNLOCKED , TARGET_LOCK_EXCLUSIVE ,
246
- target , & result_value , sizeof (result_value ),
247
- remote_addr );
248
- if (ret != OMPI_SUCCESS ) {
249
- OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_cmpswp failed: %d" , ret );
250
- return OMPI_ERROR ;
251
- }
252
- if (result_value == TARGET_LOCK_UNLOCKED ) {
253
- return OMPI_SUCCESS ;
256
+ if (need_acc_lock (module , target )) {
257
+ for (;;) {
258
+ ret = opal_common_ucx_wpmem_cmpswp (module -> state_mem ,
259
+ TARGET_LOCK_UNLOCKED , TARGET_LOCK_EXCLUSIVE ,
260
+ target , & result_value , sizeof (result_value ),
261
+ remote_addr );
262
+ if (ret != OMPI_SUCCESS ) {
263
+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_cmpswp failed: %d" , ret );
264
+ return OMPI_ERROR ;
265
+ }
266
+ if (result_value == TARGET_LOCK_UNLOCKED ) {
267
+ return OMPI_SUCCESS ;
268
+ }
269
+
270
+ ucp_worker_progress (mca_osc_ucx_component .wpool -> dflt_worker );
254
271
}
255
272
256
- ucp_worker_progress (mca_osc_ucx_component .wpool -> dflt_worker );
273
+ * lock_acquired = true;
274
+ } else {
275
+ * lock_acquired = false;
257
276
}
258
277
}
259
278
260
279
static inline int end_atomicity (
261
280
ompi_osc_ucx_module_t * module ,
262
281
int target ,
282
+ bool lock_acquired ,
263
283
void * free_ptr ) {
264
- uint64_t result_value = 0 ;
265
284
uint64_t remote_addr = (module -> state_addrs )[target ] + OSC_UCX_STATE_ACC_LOCK_OFFSET ;
266
285
int ret = OMPI_SUCCESS ;
267
286
268
- /* fence any still active operations */
269
- ret = opal_common_ucx_wpmem_fence (module -> mem );
270
- if (ret != OMPI_SUCCESS ) {
271
- OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
272
- return OMPI_ERROR ;
273
- }
274
-
275
- ret = opal_common_ucx_wpmem_fetch (module -> state_mem ,
276
- UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
277
- target , & result_value , sizeof (result_value ),
278
- remote_addr );
287
+ if (lock_acquired ) {
288
+ uint64_t result_value = 0 ;
289
+ /* fence any still active operations */
290
+ ret = opal_common_ucx_wpmem_fence (module -> mem );
291
+ if (ret != OMPI_SUCCESS ) {
292
+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
293
+ return OMPI_ERROR ;
294
+ }
279
295
296
+ ret = opal_common_ucx_wpmem_fetch (module -> state_mem ,
297
+ UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
298
+ target , & result_value , sizeof (result_value ),
299
+ remote_addr );
300
+ assert (result_value == TARGET_LOCK_EXCLUSIVE );
301
+ } else if (NULL != free_ptr ){
302
+ /* flush before freeing the buffer */
303
+ ret = opal_common_ucx_wpmem_flush (module -> state_mem , OPAL_COMMON_UCX_SCOPE_EP , target );
304
+ }
280
305
/* TODO: encapsulate in a request and make the release non-blocking */
281
306
if (NULL != free_ptr ) {
282
307
free (free_ptr );
@@ -286,8 +311,6 @@ static inline int end_atomicity(
286
311
return OMPI_ERROR ;
287
312
}
288
313
289
- assert (result_value == TARGET_LOCK_EXCLUSIVE );
290
-
291
314
return ret ;
292
315
}
293
316
@@ -562,6 +585,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
562
585
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
563
586
int ret = OMPI_SUCCESS ;
564
587
void * free_ptr = NULL ;
588
+ bool lock_acquired = false;
565
589
566
590
ret = check_sync_state (module , target , false);
567
591
if (ret != OMPI_SUCCESS ) {
@@ -579,8 +603,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
579
603
NULL , ucx_req );
580
604
}
581
605
582
-
583
- ret = start_atomicity (module , target );
606
+ ret = start_atomicity (module , target , & lock_acquired );
584
607
if (ret != OMPI_SUCCESS ) {
585
608
return ret ;
586
609
}
@@ -681,7 +704,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
681
704
ompi_request_complete (& ucx_req -> super , true);
682
705
}
683
706
684
- return end_atomicity (module , target , free_ptr );
707
+ return end_atomicity (module , target , lock_acquired , free_ptr );
685
708
}
686
709
687
710
int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
@@ -701,17 +724,16 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
701
724
uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
702
725
size_t dt_bytes ;
703
726
int ret = OMPI_SUCCESS ;
727
+ bool lock_acquired = false;
704
728
705
729
ret = check_sync_state (module , target , false);
706
730
if (ret != OMPI_SUCCESS ) {
707
731
return ret ;
708
732
}
709
733
710
- if (!module -> acc_single_intrinsic ) {
711
- ret = start_atomicity (module , target );
712
- if (ret != OMPI_SUCCESS ) {
713
- return ret ;
714
- }
734
+ ret = start_atomicity (module , target , & lock_acquired );
735
+ if (ret != OMPI_SUCCESS ) {
736
+ return ret ;
715
737
}
716
738
717
739
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
@@ -738,7 +760,7 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
738
760
return ret ;
739
761
}
740
762
741
- return end_atomicity (module , target , NULL );
763
+ return end_atomicity (module , target , lock_acquired , NULL );
742
764
}
743
765
744
766
int ompi_osc_ucx_fetch_and_op (const void * origin_addr , void * result_addr ,
@@ -759,9 +781,10 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
759
781
uint64_t value = origin_addr ? * (uint64_t * )origin_addr : 0 ;
760
782
ucp_atomic_fetch_op_t opcode ;
761
783
size_t dt_bytes ;
784
+ bool lock_acquired = false;
762
785
763
786
if (!module -> acc_single_intrinsic ) {
764
- ret = start_atomicity (module , target );
787
+ ret = start_atomicity (module , target , & lock_acquired );
765
788
if (ret != OMPI_SUCCESS ) {
766
789
return ret ;
767
790
}
@@ -792,7 +815,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
792
815
return ret ;
793
816
}
794
817
795
- return end_atomicity (module , target , NULL );
818
+ return end_atomicity (module , target , lock_acquired , NULL );
796
819
} else {
797
820
return ompi_osc_ucx_get_accumulate (origin_addr , 1 , dt , result_addr , 1 , dt ,
798
821
target , target_disp , 1 , dt , op , win );
@@ -811,6 +834,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
811
834
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
812
835
int ret = OMPI_SUCCESS ;
813
836
void * free_addr = NULL ;
837
+ bool lock_acquired = false;
814
838
815
839
ret = check_sync_state (module , target , false);
816
840
if (ret != OMPI_SUCCESS ) {
@@ -824,7 +848,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
824
848
result_addr , ucx_req );
825
849
}
826
850
827
- ret = start_atomicity (module , target );
851
+ ret = start_atomicity (module , target , & lock_acquired );
828
852
if (ret != OMPI_SUCCESS ) {
829
853
return ret ;
830
854
}
@@ -933,7 +957,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
933
957
}
934
958
935
959
936
- return end_atomicity (module , target , free_addr );
960
+ return end_atomicity (module , target , lock_acquired , free_addr );
937
961
}
938
962
939
963
int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
0 commit comments