@@ -323,7 +323,7 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323
323
return ret ;
324
324
}
325
325
326
- static int atomic_op_replace_sum (
326
+ static int do_atomic_op_replace_sum (
327
327
ompi_osc_ucx_module_t * module ,
328
328
struct ompi_op_t * op ,
329
329
int target ,
@@ -333,7 +333,8 @@ static int atomic_op_replace_sum(
333
333
ptrdiff_t target_disp ,
334
334
int target_count ,
335
335
struct ompi_datatype_t * target_dt ,
336
- void * result_addr )
336
+ void * result_addr ,
337
+ ompi_osc_ucx_request_t * ucx_req )
337
338
{
338
339
int ret = OMPI_SUCCESS ;
339
340
size_t origin_dt_bytes ;
@@ -363,12 +364,27 @@ static int atomic_op_replace_sum(
363
364
opcode = UCP_ATOMIC_FETCH_OP_FADD ;
364
365
}
365
366
367
+ opal_common_ucx_user_req_handler_t user_req_cb = NULL ;
368
+ void * user_req_ptr = NULL ;
366
369
for (int i = 0 ; i < origin_count ; ++ i ) {
367
370
uint64_t value = 0 ;
371
+ if ((origin_count - 1 ) == i && NULL != ucx_req ) {
372
+ // the last item is used to feed the request, if needed
373
+ user_req_cb = & req_completion ;
374
+ user_req_ptr = ucx_req ;
375
+ // issue a fence if this is the last but not the only element
376
+ if (0 < i ) {
377
+ ret = opal_common_ucx_wpmem_fence (module -> mem );
378
+ if (ret != OMPI_SUCCESS ) {
379
+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
380
+ return OMPI_ERROR ;
381
+ }
382
+ }
383
+ }
368
384
memcpy (& value , origin_addr , origin_dt_bytes );
369
385
ret = opal_common_ucx_wpmem_fetch_nb (module -> mem , opcode , value , target ,
370
386
result_addr ? result_addr : & (module -> req_result ),
371
- origin_dt_bytes , remote_addr , NULL , NULL );
387
+ origin_dt_bytes , remote_addr , user_req_cb , user_req_ptr );
372
388
373
389
// advance origin and remote address
374
390
origin_addr = (void * )((intptr_t )origin_addr + origin_dt_bytes );
@@ -381,7 +397,7 @@ static int atomic_op_replace_sum(
381
397
return ret ;
382
398
}
383
399
384
- static int atomic_op_cswap (
400
+ static int do_atomic_op_cswap (
385
401
ompi_osc_ucx_module_t * module ,
386
402
struct ompi_op_t * op ,
387
403
int target ,
@@ -391,7 +407,8 @@ static int atomic_op_cswap(
391
407
ptrdiff_t target_disp ,
392
408
int target_count ,
393
409
struct ompi_datatype_t * target_dt ,
394
- void * result_addr )
410
+ void * result_addr ,
411
+ ompi_osc_ucx_request_t * ucx_req )
395
412
{
396
413
int ret = OMPI_SUCCESS ;
397
414
size_t origin_dt_bytes ;
@@ -432,6 +449,7 @@ static int atomic_op_cswap(
432
449
return ret ;
433
450
}
434
451
452
+ /* JS: move this loop into the request to overlap multiple cas operations? */
435
453
do {
436
454
437
455
tmp_val = target_val ;
@@ -451,6 +469,8 @@ static int atomic_op_cswap(
451
469
break ;
452
470
}
453
471
472
+ target_val = tmp_val ;
473
+
454
474
} while (1 );
455
475
456
476
// store the result if necessary
@@ -463,6 +483,41 @@ static int atomic_op_cswap(
463
483
remote_addr += origin_dt_bytes ;
464
484
}
465
485
486
+ if (NULL != ucx_req ) {
487
+ // nothing to wait for so mark the request as completed
488
+ ompi_request_complete (& ucx_req -> super , true);
489
+ }
490
+
491
+ return ret ;
492
+ }
493
+
494
+ static inline
495
+ int do_atomic_op (
496
+ ompi_osc_ucx_module_t * module ,
497
+ struct ompi_op_t * op ,
498
+ int target ,
499
+ const void * origin_addr ,
500
+ int origin_count ,
501
+ struct ompi_datatype_t * origin_dt ,
502
+ ptrdiff_t target_disp ,
503
+ int target_count ,
504
+ struct ompi_datatype_t * target_dt ,
505
+ void * result_addr ,
506
+ ompi_osc_ucx_request_t * ucx_req )
507
+ {
508
+ int ret ;
509
+
510
+ if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
511
+ ret = do_atomic_op_replace_sum (module , op , target ,
512
+ origin_addr , origin_count , origin_dt ,
513
+ target_disp , target_count , target_dt ,
514
+ result_addr , ucx_req );
515
+ } else {
516
+ ret = do_atomic_op_cswap (module , op , target ,
517
+ origin_addr , origin_count , origin_dt ,
518
+ target_disp , target_count , target_dt ,
519
+ result_addr , ucx_req );
520
+ }
466
521
return ret ;
467
522
}
468
523
@@ -576,11 +631,14 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
576
631
}
577
632
}
578
633
579
- int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
580
- struct ompi_datatype_t * origin_dt ,
581
- int target , ptrdiff_t target_disp , int target_count ,
582
- struct ompi_datatype_t * target_dt ,
583
- struct ompi_op_t * op , struct ompi_win_t * win ) {
634
+ static
635
+ int accumulate_req (const void * origin_addr , int origin_count ,
636
+ struct ompi_datatype_t * origin_dt ,
637
+ int target , ptrdiff_t target_disp , int target_count ,
638
+ struct ompi_datatype_t * target_dt ,
639
+ struct ompi_op_t * op , struct ompi_win_t * win ,
640
+ ompi_osc_ucx_request_t * ucx_req ) {
641
+
584
642
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
585
643
int ret = OMPI_SUCCESS ;
586
644
@@ -594,18 +652,10 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
594
652
}
595
653
596
654
if (module -> acc_single_intrinsic ) {
597
- if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
598
- ret = atomic_op_replace_sum (module , op , target ,
599
- origin_addr , origin_count , origin_dt ,
600
- target_disp , target_count , target_dt ,
601
- & (module -> req_result ));
602
- } else {
603
- ret = atomic_op_cswap (module , op , target ,
604
- origin_addr , origin_count , origin_dt ,
605
- target_disp , target_count , target_dt ,
606
- & (module -> req_result ));
607
- }
608
- return ret ;
655
+ return do_atomic_op (module , op , target ,
656
+ origin_addr , origin_count , origin_dt ,
657
+ target_disp , target_count , target_dt ,
658
+ NULL , ucx_req );
609
659
}
610
660
611
661
@@ -712,9 +762,23 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
712
762
free (temp_addr_holder );
713
763
}
714
764
765
+ if (NULL != ucx_req ) {
766
+ // nothing to wait for, mark request as completed
767
+ ompi_request_complete (& ucx_req -> super , true);
768
+ }
769
+
715
770
return end_atomicity (module , target );
716
771
}
717
772
773
+ int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
774
+ struct ompi_datatype_t * origin_dt ,
775
+ int target , ptrdiff_t target_disp , int target_count ,
776
+ struct ompi_datatype_t * target_dt ,
777
+ struct ompi_op_t * op , struct ompi_win_t * win ) {
778
+ return accumulate_req (origin_addr , origin_count , origin_dt , target ,
779
+ target_disp , target_count , target_dt , op , win , NULL );
780
+ }
781
+
718
782
int ompi_osc_ucx_compare_and_swap (const void * origin_addr , const void * compare_addr ,
719
783
void * result_addr , struct ompi_datatype_t * dt ,
720
784
int target , ptrdiff_t target_disp ,
@@ -813,13 +877,15 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
813
877
}
814
878
}
815
879
816
- int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
817
- struct ompi_datatype_t * origin_dt ,
818
- void * result_addr , int result_count ,
819
- struct ompi_datatype_t * result_dt ,
820
- int target , ptrdiff_t target_disp ,
821
- int target_count , struct ompi_datatype_t * target_dt ,
822
- struct ompi_op_t * op , struct ompi_win_t * win ) {
880
+ static
881
+ int get_accumulate_req (const void * origin_addr , int origin_count ,
882
+ struct ompi_datatype_t * origin_dt ,
883
+ void * result_addr , int result_count ,
884
+ struct ompi_datatype_t * result_dt ,
885
+ int target , ptrdiff_t target_disp ,
886
+ int target_count , struct ompi_datatype_t * target_dt ,
887
+ struct ompi_op_t * op , struct ompi_win_t * win ,
888
+ ompi_osc_ucx_request_t * ucx_req ) {
823
889
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
824
890
int ret = OMPI_SUCCESS ;
825
891
@@ -829,19 +895,12 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
829
895
}
830
896
831
897
if (module -> acc_single_intrinsic ) {
832
- if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
833
- ret = atomic_op_replace_sum (module , op , target ,
834
- origin_addr , origin_count , origin_dt ,
835
- target_disp , target_count , target_dt , result_addr );
836
- } else {
837
- ret = atomic_op_cswap (module , op , target ,
838
- origin_addr , origin_count , origin_dt ,
839
- target_disp , target_count , target_dt , result_addr );
840
- }
841
- return ret ;
898
+ return do_atomic_op (module , op , target ,
899
+ origin_addr , origin_count , origin_dt ,
900
+ target_disp , target_count , target_dt ,
901
+ result_addr , ucx_req );
842
902
}
843
903
844
-
845
904
ret = start_atomicity (module , target );
846
905
if (ret != OMPI_SUCCESS ) {
847
906
return ret ;
@@ -953,9 +1012,28 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
953
1012
}
954
1013
}
955
1014
1015
+ if (NULL != ucx_req ) {
1016
+ // nothing to wait for, mark request as completed
1017
+ ompi_request_complete (& ucx_req -> super , true);
1018
+ }
1019
+
1020
+
956
1021
return end_atomicity (module , target );
957
1022
}
958
1023
1024
+ int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
1025
+ struct ompi_datatype_t * origin_dt ,
1026
+ void * result_addr , int result_count ,
1027
+ struct ompi_datatype_t * result_dt ,
1028
+ int target , ptrdiff_t target_disp ,
1029
+ int target_count , struct ompi_datatype_t * target_dt ,
1030
+ struct ompi_op_t * op , struct ompi_win_t * win ) {
1031
+
1032
+ return get_accumulate_req (origin_addr , origin_count , origin_dt , result_addr ,
1033
+ result_count , result_dt , target , target_disp ,
1034
+ target_count , target_dt , op , win , NULL );
1035
+ }
1036
+
959
1037
int ompi_osc_ucx_rput (const void * origin_addr , int origin_count ,
960
1038
struct ompi_datatype_t * origin_dt ,
961
1039
int target , ptrdiff_t target_disp , int target_count ,
@@ -1077,14 +1155,13 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
1077
1155
OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1078
1156
assert (NULL != ucx_req );
1079
1157
1080
- ret = ompi_osc_ucx_accumulate (origin_addr , origin_count , origin_dt , target , target_disp ,
1081
- target_count , target_dt , op , win );
1158
+ ret = accumulate_req (origin_addr , origin_count , origin_dt , target , target_disp ,
1159
+ target_count , target_dt , op , win , ucx_req );
1082
1160
if (ret != OMPI_SUCCESS ) {
1083
1161
OMPI_OSC_UCX_REQUEST_RETURN (ucx_req );
1084
1162
return ret ;
1085
1163
}
1086
1164
1087
- ompi_request_complete (& ucx_req -> super , true);
1088
1165
* request = & ucx_req -> super ;
1089
1166
1090
1167
return ret ;
@@ -1110,17 +1187,15 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
1110
1187
OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1111
1188
assert (NULL != ucx_req );
1112
1189
1113
- ret = ompi_osc_ucx_get_accumulate (origin_addr , origin_count , origin_datatype ,
1114
- result_addr , result_count , result_datatype ,
1115
- target , target_disp , target_count ,
1116
- target_datatype , op , win );
1190
+ ret = get_accumulate_req (origin_addr , origin_count , origin_datatype ,
1191
+ result_addr , result_count , result_datatype ,
1192
+ target , target_disp , target_count ,
1193
+ target_datatype , op , win , ucx_req );
1117
1194
if (ret != OMPI_SUCCESS ) {
1118
1195
OMPI_OSC_UCX_REQUEST_RETURN (ucx_req );
1119
1196
return ret ;
1120
1197
}
1121
1198
1122
- ompi_request_complete (& ucx_req -> super , true);
1123
-
1124
1199
* request = & ucx_req -> super ;
1125
1200
1126
1201
return ret ;
0 commit comments