@@ -264,7 +264,7 @@ static inline int start_atomicity(
264
264
return OMPI_ERROR ;
265
265
}
266
266
if (result_value == TARGET_LOCK_UNLOCKED ) {
267
- return OMPI_SUCCESS ;
267
+ break ;
268
268
}
269
269
270
270
ucp_worker_progress (mca_osc_ucx_component .wpool -> dflt_worker );
@@ -274,6 +274,8 @@ static inline int start_atomicity(
274
274
} else {
275
275
* lock_acquired = false;
276
276
}
277
+
278
+ return OMPI_SUCCESS ;
277
279
}
278
280
279
281
static inline int end_atomicity (
@@ -362,16 +364,30 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
362
364
}
363
365
364
366
static inline
365
- bool use_ucx_op (struct ompi_op_t * op , struct ompi_datatype_t * origin_dt )
367
+ bool use_atomic_op (
368
+ ompi_osc_ucx_module_t * module ,
369
+ struct ompi_op_t * op ,
370
+ struct ompi_datatype_t * origin_dt ,
371
+ struct ompi_datatype_t * target_dt ,
372
+ int origin_count ,
373
+ int target_count )
366
374
{
367
375
368
- if (op == & ompi_mpi_op_replace .op ||
369
- op == & ompi_mpi_op_sum .op ||
370
- op == & ompi_mpi_op_no_op .op ) {
371
- size_t dt_bytes ;
372
- ompi_datatype_type_size (origin_dt , & dt_bytes );
373
- if (ompi_datatype_is_predefined (origin_dt ) &&
374
- sizeof (uint64_t ) >= dt_bytes ) {
376
+ if (module -> acc_single_intrinsic &&
377
+ ompi_datatype_is_predefined (origin_dt ) &&
378
+ origin_count == 1 &&
379
+ (op == & ompi_mpi_op_replace .op ||
380
+ op == & ompi_mpi_op_sum .op ||
381
+ op == & ompi_mpi_op_no_op .op )) {
382
+ size_t origin_dt_bytes ;
383
+ size_t target_dt_bytes ;
384
+ ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
385
+ ompi_datatype_type_size (target_dt , & target_dt_bytes );
386
+ /* UCX only supports 32 and 64-bit operands atm */
387
+ if (sizeof (uint64_t ) >= origin_dt_bytes &&
388
+ sizeof (uint32_t ) <= origin_dt_bytes &&
389
+ origin_dt_bytes == target_dt_bytes &&
390
+ origin_count == target_count ) {
375
391
return true;
376
392
}
377
393
}
@@ -384,25 +400,15 @@ static int do_atomic_op_intrinsic(
384
400
struct ompi_op_t * op ,
385
401
int target ,
386
402
const void * origin_addr ,
387
- int origin_count ,
388
- struct ompi_datatype_t * origin_dt ,
403
+ int count ,
404
+ struct ompi_datatype_t * dt ,
389
405
ptrdiff_t target_disp ,
390
- int target_count ,
391
- struct ompi_datatype_t * target_dt ,
392
406
void * result_addr ,
393
407
ompi_osc_ucx_request_t * ucx_req )
394
408
{
395
409
int ret = OMPI_SUCCESS ;
396
410
size_t origin_dt_bytes ;
397
- size_t target_dt_bytes ;
398
- ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
399
- ompi_datatype_type_size (target_dt , & target_dt_bytes );
400
-
401
- if (sizeof (uint64_t ) > origin_dt_bytes ||
402
- origin_dt_bytes != target_dt_bytes ||
403
- target_count != origin_count ) {
404
- return OMPI_ERR_NOT_SUPPORTED ;
405
- }
411
+ ompi_datatype_type_size (dt , & origin_dt_bytes );
406
412
407
413
uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
408
414
@@ -430,9 +436,9 @@ static int do_atomic_op_intrinsic(
430
436
if ( result_addr ) {
431
437
output_addr = result_addr ;
432
438
}
433
- for (int i = 0 ; i < origin_count ; ++ i ) {
439
+ for (int i = 0 ; i < count ; ++ i ) {
434
440
uint64_t value = 0 ;
435
- if ((origin_count - 1 ) == i && NULL != ucx_req ) {
441
+ if ((count - 1 ) == i && NULL != ucx_req ) {
436
442
// the last item is used to feed the request, if needed
437
443
user_req_cb = & req_completion ;
438
444
user_req_ptr = ucx_req ;
@@ -596,11 +602,11 @@ int accumulate_req(const void *origin_addr, int origin_count,
596
602
return ret ;
597
603
}
598
604
599
- if (module -> acc_single_intrinsic && use_ucx_op (op , origin_dt )) {
605
+ /* rely on UCX network atomics if the user told us that it safe */
606
+ if (use_atomic_op (module , op , origin_dt , target_dt , origin_count , target_count )) {
600
607
return do_atomic_op_intrinsic (module , op , target ,
601
608
origin_addr , origin_count , origin_dt ,
602
- target_disp , target_count , target_dt ,
603
- NULL , ucx_req );
609
+ target_disp , NULL , ucx_req );
604
610
}
605
611
606
612
ret = start_atomicity (module , target , & lock_acquired );
@@ -726,14 +732,21 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
726
732
int ret = OMPI_SUCCESS ;
727
733
bool lock_acquired = false;
728
734
735
+ ompi_datatype_type_size (dt , & dt_bytes );
736
+ if (sizeof (uint64_t ) < dt_bytes ) {
737
+ return OMPI_ERR_NOT_SUPPORTED ;
738
+ }
739
+
729
740
ret = check_sync_state (module , target , false);
730
741
if (ret != OMPI_SUCCESS ) {
731
742
return ret ;
732
743
}
733
744
734
- ret = start_atomicity (module , target , & lock_acquired );
735
- if (ret != OMPI_SUCCESS ) {
736
- return ret ;
745
+ if (!module -> acc_single_intrinsic ) {
746
+ ret = start_atomicity (module , target , & lock_acquired );
747
+ if (ret != OMPI_SUCCESS ) {
748
+ return ret ;
749
+ }
737
750
}
738
751
739
752
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
@@ -743,11 +756,6 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
743
756
}
744
757
}
745
758
746
- ompi_datatype_type_size (dt , & dt_bytes );
747
- if (sizeof (uint64_t ) < dt_bytes ) {
748
- return OMPI_ERR_NOT_SUPPORTED ;
749
- }
750
-
751
759
uint64_t compare_val ;
752
760
memcpy (& compare_val , compare_addr , dt_bytes );
753
761
memcpy (result_addr , origin_addr , dt_bytes );
@@ -841,11 +849,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
841
849
return ret ;
842
850
}
843
851
844
- if (module -> acc_single_intrinsic && use_ucx_op (op , origin_dt )) {
852
+ /* rely on UCX network atomics if the user told us that it safe */
853
+ if (use_atomic_op (module , op , origin_dt , target_dt , origin_count , target_count )) {
845
854
return do_atomic_op_intrinsic (module , op , target ,
846
855
origin_addr , origin_count , origin_dt ,
847
- target_disp , target_count , target_dt ,
848
- result_addr , ucx_req );
856
+ target_disp , result_addr , ucx_req );
849
857
}
850
858
851
859
ret = start_atomicity (module , target , & lock_acquired );
0 commit comments