@@ -1130,7 +1130,7 @@ int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, int ta
     if (g_cxl.initialized && getenv("CXL_SHIM_COPY_SEND")) {
         void *cxl_buf = allocate_cxl_memory(total_size);
         if (cxl_buf) {
-            memcpy(cxl_buf, buf, total_size);
+            cxl_safe_memcpy(cxl_buf, buf, total_size);
             send_buf = cxl_buf;
             LOG_TRACE("MPI_Send[%d]: copied %zu bytes to CXL at %p (rptr=0x%lx)\n",
                       call_num, total_size, cxl_buf, ptr_to_rptr(cxl_buf));
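Both the blocking and non-blocking send paths make the same memcpy-to-cxl_safe_memcpy swap when staging the payload through CXL memory. The wrapper's body is not part of this diff; the following is only a hypothetical sketch of what such a helper might do, assuming the shim tracks its mapped CXL region in illustrative g_cxl.cxl_base / g_cxl.cxl_size fields and wants stores into that region published before peers read them.

/* Hypothetical sketch -- not the shim's actual cxl_safe_memcpy.
 * Assumes <string.h> and <stdint.h> are already included by the shim and
 * that g_cxl.cxl_base / g_cxl.cxl_size (illustrative names) bound the
 * mapped CXL region. */
static void *cxl_safe_memcpy(void *dst, const void *src, size_t n) {
    uintptr_t d  = (uintptr_t)dst;
    uintptr_t lo = (uintptr_t)g_cxl.cxl_base;
    uintptr_t hi = lo + g_cxl.cxl_size;

    memcpy(dst, src, n);
    if (d >= lo && d + n <= hi) {
        /* Destination lies in CXL memory: make the copy visible before
         * other ranks are told where to find it. */
        __atomic_thread_fence(__ATOMIC_SEQ_CST);
    }
    return dst;
}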
@@ -1199,7 +1199,7 @@ int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
     int ret = orig_MPI_Recv(recv_buf, count, datatype, source, tag, comm, status);
 
     if (cxl_buf && ret == MPI_SUCCESS) {
-        memcpy(buf, cxl_buf, max_size);
+        cxl_safe_memcpy(buf, cxl_buf, max_size);
         LOG_TRACE("MPI_Recv[%d]: copied %zu bytes from CXL\n", call_num, max_size);
     }
 
@@ -1245,7 +1245,7 @@ int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int t
     if (g_cxl.initialized && getenv("CXL_SHIM_COPY_SEND")) {
         void *cxl_buf = allocate_cxl_memory(total_size);
         if (cxl_buf) {
-            memcpy(cxl_buf, buf, total_size);
+            cxl_safe_memcpy(cxl_buf, buf, total_size);
             send_buf = cxl_buf;
             LOG_TRACE("MPI_Isend[%d]: copied %zu bytes to CXL at %p\n",
                       call_num, total_size, cxl_buf);
@@ -1393,10 +1393,18 @@ static cxl_window_t *register_cxl_window(MPI_Win win, void *base, size_t size,
 
     // Allocate shared memory for window metadata if CXL is available
     if (g_cxl.initialized && g_cxl.cxl_comm_enabled) {
-        cxl_win->shm = (cxl_win_shm_t *)allocate_cxl_memory(sizeof(cxl_win_shm_t));
-        if (cxl_win->shm) {
-            // Initialize on first rank
-            if (rank == 0) {
+        LOAD_ORIGINAL(MPI_Bcast);
+        LOAD_ORIGINAL(MPI_Barrier);
+
+        // Only rank 0 allocates the shared window metadata structure,
+        // then broadcasts the rptr to all ranks so they share the same structure
+        cxl_rptr_t shm_rptr = CXL_RPTR_NULL;
+
+        if (rank == 0) {
+            cxl_win->shm = (cxl_win_shm_t *)allocate_cxl_memory(sizeof(cxl_win_shm_t));
+            if (cxl_win->shm) {
+                shm_rptr = ptr_to_rptr(cxl_win->shm);
+                // Initialize the shared structure
                 cxl_win->shm->magic = CXL_WIN_MAGIC;
                 cxl_win->shm->win_id = cxl_win->win_id;
                 atomic_store(&cxl_win->shm->ref_count, 0);
@@ -1405,37 +1413,56 @@ static cxl_window_t *register_cxl_window(MPI_Win win, void *base, size_t size,
                 cxl_win->shm->disp_unit = 1;
                 atomic_store(&cxl_win->shm->barrier_count, 0);
                 atomic_store(&cxl_win->shm->barrier_sense, 0);
+                // Initialize all rank info entries
+                for (int r = 0; r < comm_size && r < CXL_MAX_RANKS; r++) {
+                    cxl_win->shm->ranks[r].base_rptr = CXL_RPTR_NULL;
+                    cxl_win->shm->ranks[r].size = 0;
+                    cxl_win->shm->ranks[r].owner_rank = r;
+                    atomic_store(&cxl_win->shm->ranks[r].lock_count, 0);
+                    atomic_store(&cxl_win->shm->ranks[r].exclusive_lock, (uint32_t)-1);
+                    atomic_store(&cxl_win->shm->ranks[r].fence_counter, 0);
+                    atomic_store(&cxl_win->shm->ranks[r].sync_state, CXL_WIN_UNLOCKED);
+                }
+                __atomic_thread_fence(__ATOMIC_SEQ_CST);
             }
+        }
+
+        // Broadcast the shm rptr from rank 0 to all ranks
+        orig_MPI_Bcast(&shm_rptr, sizeof(shm_rptr), MPI_BYTE, 0, comm);
 
+        // All ranks convert rptr to local pointer
+        if (shm_rptr != CXL_RPTR_NULL) {
+            cxl_win->shm = (cxl_win_shm_t *)rptr_to_ptr(shm_rptr);
+        }
+
+        if (cxl_win->shm) {
+            // Now all ranks share the same shm structure
             // Register this rank's window region
             cxl_win_rank_info_t *rank_info = &cxl_win->shm->ranks[rank];
 
-            // If base is in CXL memory, use it directly; otherwise copy
+            // Only use CXL acceleration if base is already in CXL memory.
+            // Copying non-CXL buffers would break MPI semantics since
+            // Put/Get/Accumulate would modify the copy, not the original.
             if (is_cxl_ptr(base)) {
                 rank_info->base_rptr = ptr_to_rptr(base);
+                rank_info->size = size;
+                __atomic_thread_fence(__ATOMIC_SEQ_CST);
+                atomic_fetch_add(&cxl_win->shm->ref_count, 1);
+                cxl_win->cxl_enabled = true;
             } else {
-                // Allocate CXL memory for this window
-                void *cxl_base = allocate_cxl_memory(size);
-                if (cxl_base) {
-                    cxl_safe_memcpy(cxl_base, base, size);
-                    rank_info->base_rptr = ptr_to_rptr(cxl_base);
-                } else {
-                    rank_info->base_rptr = CXL_RPTR_NULL;
-                }
+                // Non-CXL buffer - disable CXL acceleration for this window
+                rank_info->base_rptr = CXL_RPTR_NULL;
+                rank_info->size = 0;
+                cxl_win->cxl_enabled = false;
+                LOG_DEBUG("Window base %p is not in CXL memory, disabling CXL acceleration\n", base);
             }
 
-            rank_info->size = size;
-            rank_info->owner_rank = rank;
-            atomic_store(&rank_info->lock_count, 0);
-            atomic_store(&rank_info->exclusive_lock, (uint32_t)-1);
-            atomic_store(&rank_info->fence_counter, 0);
-            atomic_store(&rank_info->sync_state, CXL_WIN_UNLOCKED);
-
-            atomic_fetch_add(&cxl_win->shm->ref_count, 1);
-            cxl_win->cxl_enabled = (rank_info->base_rptr != CXL_RPTR_NULL);
+            // Barrier to ensure all ranks have registered before proceeding
+            orig_MPI_Barrier(comm);
 
-            LOG_DEBUG("Registered CXL window %u for rank %d: base_rptr=0x%lx, size=%zu\n",
-                      cxl_win->win_id, rank, rank_info->base_rptr, size);
+            LOG_DEBUG("Registered CXL window %u for rank %d: shm_rptr=0x%lx, base_rptr=0x%lx, size=%zu\n",
+                      cxl_win->win_id, rank, (unsigned long)shm_rptr,
+                      (unsigned long)rank_info->base_rptr, size);
         }
     }
 
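The registration flow above (rank 0 allocates the shared structure, broadcasts a relocatable pointer, every rank translates it locally, and a barrier closes the registration epoch) is a standard MPI idiom. Below is a minimal, self-contained illustration using plain MPI calls, with an integer handle standing in for cxl_rptr_t; all names here are illustrative, not the shim's.

#include <mpi.h>
#include <stdint.h>
#include <stdio.h>

int main(int argc, char **argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    uint64_t handle = 0;              /* stands in for cxl_rptr_t / CXL_RPTR_NULL */
    if (rank == 0)
        handle = 0x1000;              /* rank 0 "allocates" and publishes the offset */

    /* every rank receives the same handle and would translate it locally */
    MPI_Bcast(&handle, sizeof(handle), MPI_BYTE, 0, MPI_COMM_WORLD);
    printf("rank %d sees handle 0x%lx\n", rank, (unsigned long)handle);

    MPI_Barrier(MPI_COMM_WORLD);      /* all ranks registered before first use */
    MPI_Finalize();
    return 0;
}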
@@ -1740,26 +1767,22 @@ int MPI_Win_fence(int assert, MPI_Win win) {
     cxl_window_t *cxl_win = find_cxl_window(win);
 
     if (cxl_win && cxl_win->cxl_enabled && cxl_win->shm) {
-        // Full memory barrier
+        // Full memory barrier to ensure all CXL writes are visible
         __atomic_thread_fence(__ATOMIC_SEQ_CST);
 
-        // Increment local fence counter
+        // Track fence epochs for debugging
         cxl_win_rank_info_t *my_info = &cxl_win->shm->ranks[cxl_win->my_rank];
         uint64_t my_fence = atomic_fetch_add(&my_info->fence_counter, 1) + 1;
-
-        // Increment global fence
         atomic_fetch_add(&cxl_win->shm->global_fence, 1);
 
-        // Wait for all ranks to reach this fence (barrier)
-        cxl_barrier(cxl_win->shm, cxl_win->my_rank, cxl_win->comm_size);
+        // Use MPI fence for synchronization - it will barrier internally
+        int ret = orig_MPI_Win_fence(assert, win);
 
         // Another memory barrier after synchronization
         __atomic_thread_fence(__ATOMIC_SEQ_CST);
 
         LOG_DEBUG("MPI_Win_fence[%d]: CXL fence completed (epoch=%lu)\n", call_num, my_fence);
-
-        // Still call original for compatibility
-        return orig_MPI_Win_fence(assert, win);
+        return ret;
     }
 
     return orig_MPI_Win_fence(assert, win);
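The cxl_barrier call removed here is not shown in this diff. Given the barrier_count and barrier_sense fields kept in the shared window structure, the usual shape for such a primitive is a sense-reversing barrier; the sketch below illustrates that technique only and is not the shim's actual implementation.

#include <stdatomic.h>

typedef struct {
    atomic_uint count;   /* arrivals in the current episode (cf. barrier_count) */
    atomic_uint sense;   /* flips once per completed barrier (cf. barrier_sense) */
} toy_barrier_t;

static void toy_barrier(toy_barrier_t *b, int nranks) {
    unsigned my_sense = !atomic_load(&b->sense);           /* sense we wait for   */
    if (atomic_fetch_add(&b->count, 1) + 1 == (unsigned)nranks) {
        atomic_store(&b->count, 0);                        /* last arrival resets */
        atomic_store(&b->sense, my_sense);                 /* ...and releases all */
    } else {
        while (atomic_load(&b->sense) != my_sense)
            ;                                              /* spin until released */
    }
}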
@@ -1936,7 +1959,9 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype da
     LOAD_ORIGINAL(MPI_Allreduce);
 
     // For small allreduce on COMM_WORLD with SUM, try CXL optimization
+    // Note: Skip if sendbuf is MPI_IN_PLACE to avoid complexity
     if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD &&
+        sendbuf != MPI_IN_PLACE &&
         op == MPI_SUM && total_size <= 4096 && g_cxl.world_size <= 64 &&
         (datatype == MPI_DOUBLE || datatype == MPI_FLOAT || datatype == MPI_INT ||
          datatype == MPI_LONG || datatype == MPI_LONG_LONG)) {
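For reference, an MPI_IN_PLACE reduction keeps the caller's input in recvbuf rather than sendbuf, as in the hypothetical call below, so a fast path that stages sendbuf into CXL memory would read the wrong buffer; the new guard routes such calls straight to the original MPI_Allreduce. The same reasoning applies to the MPI_Allgather guard in the next hunk.

double x = 1.0;  /* each rank's contribution lives in the receive buffer */
MPI_Allreduce(MPI_IN_PLACE, &x, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);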
@@ -2031,8 +2056,10 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
     LOAD_ORIGINAL(MPI_Allgather);
 
     // CXL-optimized allgather for COMM_WORLD
-    if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && send_bytes <= 4096 &&
-        g_cxl.world_size <= 64) {
+    // Skip MPI_IN_PLACE to avoid complexity
+    if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD &&
+        sendbuf != MPI_IN_PLACE &&
+        send_bytes <= 4096 && g_cxl.world_size <= 64) {
 
         LOAD_ORIGINAL(MPI_Barrier);
 