@@ -323,7 +323,25 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323
323
return ret ;
324
324
}
325
325
326
- static int do_atomic_op_replace_sum (
326
+ static inline
327
+ bool use_ucx_op (struct ompi_op_t * op , struct ompi_datatype_t * origin_dt )
328
+ {
329
+
330
+ if (op == & ompi_mpi_op_replace .op ||
331
+ op == & ompi_mpi_op_sum .op ||
332
+ op == & ompi_mpi_op_no_op .op ) {
333
+ size_t dt_bytes ;
334
+ ompi_datatype_type_size (origin_dt , & dt_bytes );
335
+ if (ompi_datatype_is_predefined (origin_dt ) &&
336
+ sizeof (uint64_t ) >= dt_bytes ) {
337
+ return true;
338
+ }
339
+ }
340
+
341
+ return false;
342
+ }
343
+
344
+ static int do_atomic_op_intrinsic (
327
345
ompi_osc_ucx_module_t * module ,
328
346
struct ompi_op_t * op ,
329
347
int target ,
@@ -342,7 +360,7 @@ static int do_atomic_op_replace_sum(
342
360
ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
343
361
ompi_datatype_type_size (target_dt , & target_dt_bytes );
344
362
345
- if (origin_dt_bytes > sizeof (uint64_t ) ||
363
+ if (sizeof (uint64_t ) > origin_dt_bytes ||
346
364
origin_dt_bytes != target_dt_bytes ||
347
365
target_count != origin_count ) {
348
366
return OMPI_ERR_NOT_SUPPORTED ;
@@ -409,133 +427,6 @@ static int do_atomic_op_replace_sum(
409
427
return ret ;
410
428
}
411
429
412
- static int do_atomic_op_cswap (
413
- ompi_osc_ucx_module_t * module ,
414
- struct ompi_op_t * op ,
415
- int target ,
416
- const void * origin_addr ,
417
- int origin_count ,
418
- struct ompi_datatype_t * origin_dt ,
419
- ptrdiff_t target_disp ,
420
- int target_count ,
421
- struct ompi_datatype_t * target_dt ,
422
- void * result_addr ,
423
- ompi_osc_ucx_request_t * ucx_req )
424
- {
425
- int ret = OMPI_SUCCESS ;
426
- size_t origin_dt_bytes ;
427
- size_t target_dt_bytes ;
428
- ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
429
- ompi_datatype_type_size (target_dt , & target_dt_bytes );
430
-
431
- if (origin_dt_bytes > sizeof (uint64_t ) ||
432
- origin_dt_bytes != target_dt_bytes ||
433
- target_count != origin_count ) {
434
- return OMPI_ERR_NOT_SUPPORTED ;
435
- }
436
-
437
- uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
438
-
439
- if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
440
- ret = get_dynamic_win_info (remote_addr , module , target );
441
- if (ret != OMPI_SUCCESS ) {
442
- return ret ;
443
- }
444
- }
445
-
446
- for (int i = 0 ; i < origin_count ; ++ i ) {
447
-
448
- uint64_t tmp_val ;
449
- uint64_t target_val = 0 ;
450
-
451
- // get the value from the origin
452
- ret = opal_common_ucx_wpmem_putget (module -> mem , OPAL_COMMON_UCX_GET ,
453
- target , & target_val , origin_dt_bytes ,
454
- remote_addr );
455
- if (ret != OMPI_SUCCESS ) {
456
- return ret ;
457
- }
458
-
459
- ret = opal_common_ucx_wpmem_flush (module -> mem , OPAL_COMMON_UCX_SCOPE_EP , target );
460
- if (ret != OMPI_SUCCESS ) {
461
- return ret ;
462
- }
463
-
464
- /* JS: move this loop into the request to overlap multiple cas operations? */
465
- do {
466
-
467
- tmp_val = target_val ;
468
- // compute the result value
469
- ompi_op_reduce (op , (void * )origin_addr , & tmp_val , 1 , origin_dt );
470
-
471
- // compare-and-swap the resulting value
472
- ret = opal_common_ucx_wpmem_cmpswp (module -> mem , target_val , tmp_val ,
473
- target , & tmp_val , origin_dt_bytes ,
474
- remote_addr );
475
- if (ret != OMPI_SUCCESS ) {
476
- return ret ;
477
- }
478
-
479
- // check whether the conditional swap was successful
480
- if (tmp_val == target_val ) {
481
- break ;
482
- }
483
-
484
- target_val = tmp_val ;
485
-
486
- } while (1 );
487
-
488
- // store the result if necessary
489
- if (NULL != result_addr ) {
490
- memcpy (result_addr , & tmp_val , origin_dt_bytes );
491
- result_addr = (void * )((intptr_t )result_addr + origin_dt_bytes );
492
- }
493
- // advance origin and remote address
494
- origin_addr = (void * )((intptr_t )origin_addr + origin_dt_bytes );
495
- remote_addr += origin_dt_bytes ;
496
- }
497
-
498
- if (NULL != ucx_req ) {
499
- // nothing to wait for so mark the request as completed
500
- ompi_request_complete (& ucx_req -> super , true);
501
- }
502
-
503
- return ret ;
504
- }
505
-
506
- static inline
507
- int do_atomic_op (
508
- ompi_osc_ucx_module_t * module ,
509
- struct ompi_op_t * op ,
510
- int target ,
511
- const void * origin_addr ,
512
- int origin_count ,
513
- struct ompi_datatype_t * origin_dt ,
514
- ptrdiff_t target_disp ,
515
- int target_count ,
516
- struct ompi_datatype_t * target_dt ,
517
- void * result_addr ,
518
- ompi_osc_ucx_request_t * ucx_req )
519
- {
520
- int ret ;
521
-
522
- if (op == & ompi_mpi_op_replace .op ||
523
- op == & ompi_mpi_op_sum .op ||
524
- op == & ompi_mpi_op_no_op .op ) {
525
- ret = do_atomic_op_replace_sum (module , op , target ,
526
- origin_addr , origin_count , origin_dt ,
527
- target_disp , target_count , target_dt ,
528
- result_addr , ucx_req );
529
- } else {
530
- ret = do_atomic_op_cswap (module , op , target ,
531
- origin_addr , origin_count , origin_dt ,
532
- target_disp , target_count , target_dt ,
533
- result_addr , ucx_req );
534
- }
535
- return ret ;
536
- }
537
-
538
-
539
430
int ompi_osc_ucx_put (const void * origin_addr , int origin_count , struct ompi_datatype_t * origin_dt ,
540
431
int target , ptrdiff_t target_disp , int target_count ,
541
432
struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
@@ -665,11 +556,11 @@ int accumulate_req(const void *origin_addr, int origin_count,
665
556
return ret ;
666
557
}
667
558
668
- if (module -> acc_single_intrinsic ) {
669
- return do_atomic_op (module , op , target ,
670
- origin_addr , origin_count , origin_dt ,
671
- target_disp , target_count , target_dt ,
672
- NULL , ucx_req );
559
+ if (module -> acc_single_intrinsic && use_ucx_op ( op , origin_dt ) ) {
560
+ return do_atomic_op_intrinsic (module , op , target ,
561
+ origin_addr , origin_count , origin_dt ,
562
+ target_disp , target_count , target_dt ,
563
+ NULL , ucx_req );
673
564
}
674
565
675
566
@@ -923,11 +814,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
923
814
return ret ;
924
815
}
925
816
926
- if (module -> acc_single_intrinsic ) {
927
- return do_atomic_op (module , op , target ,
928
- origin_addr , origin_count , origin_dt ,
929
- target_disp , target_count , target_dt ,
930
- result_addr , ucx_req );
817
+ if (module -> acc_single_intrinsic && use_ucx_op ( op , origin_dt ) ) {
818
+ return do_atomic_op_intrinsic (module , op , target ,
819
+ origin_addr , origin_count , origin_dt ,
820
+ target_disp , target_count , target_dt ,
821
+ result_addr , ucx_req );
931
822
}
932
823
933
824
ret = start_atomicity (module , target );
0 commit comments