@@ -323,6 +323,149 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323
323
return ret ;
324
324
}
325
325
326
+ static int atomic_op_replace_sum (
327
+ ompi_osc_ucx_module_t * module ,
328
+ struct ompi_op_t * op ,
329
+ int target ,
330
+ const void * origin_addr ,
331
+ int origin_count ,
332
+ struct ompi_datatype_t * origin_dt ,
333
+ ptrdiff_t target_disp ,
334
+ int target_count ,
335
+ struct ompi_datatype_t * target_dt ,
336
+ void * result_addr )
337
+ {
338
+ int ret = OMPI_SUCCESS ;
339
+ size_t origin_dt_bytes ;
340
+ size_t target_dt_bytes ;
341
+ ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
342
+ ompi_datatype_type_size (target_dt , & target_dt_bytes );
343
+
344
+ if (origin_dt_bytes > sizeof (uint64_t ) ||
345
+ origin_dt_bytes != target_dt_bytes ||
346
+ target_count != origin_count ) {
347
+ return OMPI_ERR_NOT_SUPPORTED ;
348
+ }
349
+
350
+ uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
351
+
352
+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
353
+ ret = get_dynamic_win_info (remote_addr , module , target );
354
+ if (ret != OMPI_SUCCESS ) {
355
+ return ret ;
356
+ }
357
+ }
358
+
359
+ ucp_atomic_fetch_op_t opcode ;
360
+ if (op == & ompi_mpi_op_replace .op ) {
361
+ opcode = UCP_ATOMIC_FETCH_OP_SWAP ;
362
+ } else {
363
+ opcode = UCP_ATOMIC_FETCH_OP_FADD ;
364
+ }
365
+
366
+ for (int i = 0 ; i < origin_count ; ++ i ) {
367
+ uint64_t value = 0 ;
368
+ memcpy (& value , origin_addr , origin_dt_bytes );
369
+ ret = opal_common_ucx_wpmem_fetch_nb (module -> mem , opcode , value , target ,
370
+ result_addr ? result_addr : & (module -> req_result ),
371
+ origin_dt_bytes , remote_addr , NULL , NULL );
372
+
373
+ // advance origin and remote address
374
+ origin_addr = (void * )((intptr_t )origin_addr + origin_dt_bytes );
375
+ remote_addr += origin_dt_bytes ;
376
+ if (result_addr ) {
377
+ result_addr = (void * )((intptr_t )result_addr + origin_dt_bytes );
378
+ }
379
+ }
380
+
381
+ return ret ;
382
+ }
383
+
384
+ static int atomic_op_cswap (
385
+ ompi_osc_ucx_module_t * module ,
386
+ struct ompi_op_t * op ,
387
+ int target ,
388
+ const void * origin_addr ,
389
+ int origin_count ,
390
+ struct ompi_datatype_t * origin_dt ,
391
+ ptrdiff_t target_disp ,
392
+ int target_count ,
393
+ struct ompi_datatype_t * target_dt ,
394
+ void * result_addr )
395
+ {
396
+ int ret = OMPI_SUCCESS ;
397
+ size_t origin_dt_bytes ;
398
+ size_t target_dt_bytes ;
399
+ ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
400
+ ompi_datatype_type_size (target_dt , & target_dt_bytes );
401
+
402
+ if (origin_dt_bytes > sizeof (uint64_t ) ||
403
+ origin_dt_bytes != target_dt_bytes ||
404
+ target_count != origin_count ) {
405
+ return OMPI_ERR_NOT_SUPPORTED ;
406
+ }
407
+
408
+ uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
409
+
410
+ if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
411
+ ret = get_dynamic_win_info (remote_addr , module , target );
412
+ if (ret != OMPI_SUCCESS ) {
413
+ return ret ;
414
+ }
415
+ }
416
+
417
+ for (int i = 0 ; i < origin_count ; ++ i ) {
418
+
419
+ uint64_t tmp_val ;
420
+ do {
421
+ uint64_t target_val = 0 ;
422
+
423
+ // get the value from the origin
424
+ ret = opal_common_ucx_wpmem_putget (module -> mem , OPAL_COMMON_UCX_GET ,
425
+ target , & target_val , origin_dt_bytes ,
426
+ remote_addr );
427
+ if (ret != OMPI_SUCCESS ) {
428
+ return ret ;
429
+ }
430
+
431
+ ret = opal_common_ucx_wpmem_flush (module -> mem , OPAL_COMMON_UCX_SCOPE_EP , target );
432
+ if (ret != OMPI_SUCCESS ) {
433
+ return ret ;
434
+ }
435
+
436
+ tmp_val = target_val ;
437
+ // compute the result value
438
+ ompi_op_reduce (op , (void * )origin_addr , & tmp_val , 1 , origin_dt );
439
+
440
+ // compare-and-swap the resulting value
441
+ ret = opal_common_ucx_wpmem_cmpswp (module -> mem , target_val , tmp_val ,
442
+ target , & tmp_val , origin_dt_bytes ,
443
+ remote_addr );
444
+ if (ret != OMPI_SUCCESS ) {
445
+ return ret ;
446
+ }
447
+
448
+ // check whether the conditional swap was successful
449
+ if (tmp_val == target_val ) {
450
+ break ;
451
+ }
452
+
453
+ } while (1 );
454
+
455
+ // store the result if necessary
456
+ if (NULL != result_addr ) {
457
+ memcpy (result_addr , & tmp_val , origin_dt_bytes );
458
+ result_addr = (void * )((intptr_t )result_addr + origin_dt_bytes );
459
+ }
460
+ // advance origin and remote address
461
+ origin_addr = (void * )((intptr_t )origin_addr + origin_dt_bytes );
462
+ remote_addr += origin_dt_bytes ;
463
+ }
464
+
465
+ return ret ;
466
+ }
467
+
468
+
326
469
int ompi_osc_ucx_put (const void * origin_addr , int origin_count , struct ompi_datatype_t * origin_dt ,
327
470
int target , ptrdiff_t target_disp , int target_count ,
328
471
struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
@@ -449,6 +592,22 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
449
592
return ret ;
450
593
}
451
594
595
+ if (module -> acc_single_intrinsic ) {
596
+ if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
597
+ ret = atomic_op_replace_sum (module , op , target ,
598
+ origin_addr , origin_count , origin_dt ,
599
+ target_disp , target_count , target_dt ,
600
+ & (module -> req_result ));
601
+ } else {
602
+ ret = atomic_op_cswap (module , op , target ,
603
+ origin_addr , origin_count , origin_dt ,
604
+ target_disp , target_count , target_dt ,
605
+ & (module -> req_result ));
606
+ }
607
+ return ret ;
608
+ }
609
+
610
+
452
611
ret = start_atomicity (module , target );
453
612
if (ret != OMPI_SUCCESS ) {
454
613
return ret ;
@@ -569,9 +728,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
569
728
return ret ;
570
729
}
571
730
572
- ret = start_atomicity (module , target );
573
- if (ret != OMPI_SUCCESS ) {
574
- return ret ;
731
+ if (!module -> acc_single_intrinsic ) {
732
+ ret = start_atomicity (module , target );
733
+ if (ret != OMPI_SUCCESS ) {
734
+ return ret ;
735
+ }
575
736
}
576
737
577
738
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
@@ -585,7 +746,8 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
585
746
ret = opal_common_ucx_wpmem_cmpswp (module -> mem ,* (uint64_t * )compare_addr ,
586
747
* (uint64_t * )origin_addr , target ,
587
748
result_addr , dt_bytes , remote_addr );
588
- if (ret != OMPI_SUCCESS ) {
749
+
750
+ if (module -> acc_single_intrinsic ) {
589
751
return ret ;
590
752
}
591
753
@@ -611,9 +773,11 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
611
773
ucp_atomic_fetch_op_t opcode ;
612
774
size_t dt_bytes ;
613
775
614
- ret = start_atomicity (module , target );
615
- if (ret != OMPI_SUCCESS ) {
616
- return ret ;
776
+ if (!module -> acc_single_intrinsic ) {
777
+ ret = start_atomicity (module , target );
778
+ if (ret != OMPI_SUCCESS ) {
779
+ return ret ;
780
+ }
617
781
}
618
782
619
783
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
@@ -636,7 +800,8 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
636
800
637
801
ret = opal_common_ucx_wpmem_fetch (module -> mem , opcode , value , target ,
638
802
(void * )result_addr , dt_bytes , remote_addr );
639
- if (ret != OMPI_SUCCESS ) {
803
+
804
+ if (module -> acc_single_intrinsic ) {
640
805
return ret ;
641
806
}
642
807
@@ -662,6 +827,20 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
662
827
return ret ;
663
828
}
664
829
830
+ if (module -> acc_single_intrinsic ) {
831
+ if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
832
+ ret = atomic_op_replace_sum (module , op , target ,
833
+ origin_addr , origin_count , origin_dt ,
834
+ target_disp , target_count , target_dt , result_addr );
835
+ } else {
836
+ ret = atomic_op_cswap (module , op , target ,
837
+ origin_addr , origin_count , origin_dt ,
838
+ target_disp , target_count , target_dt , result_addr );
839
+ }
840
+ return ret ;
841
+ }
842
+
843
+
665
844
ret = start_atomicity (module , target );
666
845
if (ret != OMPI_SUCCESS ) {
667
846
return ret ;
0 commit comments