@@ -378,6 +378,35 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
378378 return ret ;
379379}
380380
381+ static inline
382+ bool osc_is_atomic_dt_op_supported (
383+ struct ompi_datatype_t * dt ,
384+ struct ompi_op_t * op ,
385+ size_t dt_bytes ,
386+ uint64_t remote_addr )
387+ {
388+ /* UCX atomics are only supported on 32 and 64 bit values */
389+ if (!ompi_datatype_is_predefined (dt ) ||
390+ !ompi_osc_base_is_atomic_size_supported (remote_addr , dt_bytes )) {
391+ return false;
392+ }
393+ /* Hardware-based atomic add for floating point is not supported */
394+ else if ((
395+ op == & ompi_mpi_op_no_op .op
396+ || op == & ompi_mpi_op_replace .op
397+ || op == & ompi_mpi_op_sum .op
398+ )
399+ && !(
400+ op == & ompi_mpi_op_sum .op
401+ && (dt == MPI_FLOAT || dt == MPI_DOUBLE
402+ || dt == MPI_LONG_DOUBLE || dt == MPI_FLOAT_INT )
403+ )) {
404+ return true;
405+ }
406+
407+ return false;
408+ }
409+
381410static inline
382411bool use_atomic_op (
383412 ompi_osc_ucx_module_t * module ,
@@ -388,25 +417,16 @@ bool use_atomic_op(
388417 int origin_count ,
389418 int target_count )
390419{
420+ size_t origin_dt_bytes ;
391421
392- if (module -> acc_single_intrinsic &&
393- ompi_datatype_is_predefined (origin_dt ) &&
394- origin_count == 1 &&
395- (op == & ompi_mpi_op_replace .op ||
396- op == & ompi_mpi_op_sum .op ||
397- op == & ompi_mpi_op_no_op .op )) {
398- size_t origin_dt_bytes ;
399- size_t target_dt_bytes ;
422+ if (!module -> acc_single_intrinsic || origin_count != 1 || target_count != 1
423+ || origin_dt != target_dt ) {
424+ return false;
425+ } else {
400426 ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
401- ompi_datatype_type_size (target_dt , & target_dt_bytes );
402- /* UCX only supports 32 and 64-bit operands atm */
403- if (ompi_osc_base_is_atomic_size_supported (remote_addr , origin_dt_bytes ) &&
404- origin_dt_bytes == target_dt_bytes && origin_count == target_count ) {
405- return true;
406- }
427+ return osc_is_atomic_dt_op_supported (origin_dt , op , origin_dt_bytes ,
428+ remote_addr );
407429 }
408-
409- return false;
410430}
411431
412432static int do_atomic_op_intrinsic (
@@ -859,10 +879,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
859879 uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
860880 ompi_datatype_type_size (dt , & dt_bytes );
861881
862- /* UCX atomics are only supported on 32 and 64 bit values */
863- if (ompi_osc_base_is_atomic_size_supported (remote_addr , dt_bytes ) &&
864- (op == & ompi_mpi_op_no_op .op || op == & ompi_mpi_op_replace .op ||
865- op == & ompi_mpi_op_sum .op )) {
882+ if (osc_is_atomic_dt_op_supported (dt , op , dt_bytes , remote_addr )) {
866883 uint64_t value ;
867884 ucp_atomic_fetch_op_t opcode ;
868885 bool lock_acquired = false;
@@ -973,6 +990,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
973990 if (ret != OMPI_SUCCESS ) {
974991 return ret ;
975992 }
993+ temp_count *= target_count ;
976994 }
977995 ompi_datatype_get_true_extent (temp_dt , & temp_lb , & temp_extent );
978996 temp_addr = free_addr = malloc (temp_extent * temp_count );
0 commit comments