Skip to content

Commit fb3420d

Browse files
committed
update
1 parent 6abc472 commit fb3420d

File tree

1 file changed

+65
-38
lines changed

1 file changed

+65
-38
lines changed

workloads/gromacs/mpi_cxl_shim.c

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,7 @@ int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, int ta
11301130
if (g_cxl.initialized && getenv("CXL_SHIM_COPY_SEND")) {
11311131
void *cxl_buf = allocate_cxl_memory(total_size);
11321132
if (cxl_buf) {
1133-
memcpy(cxl_buf, buf, total_size);
1133+
cxl_safe_memcpy(cxl_buf, buf, total_size);
11341134
send_buf = cxl_buf;
11351135
LOG_TRACE("MPI_Send[%d]: copied %zu bytes to CXL at %p (rptr=0x%lx)\n",
11361136
call_num, total_size, cxl_buf, ptr_to_rptr(cxl_buf));
@@ -1199,7 +1199,7 @@ int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
11991199
int ret = orig_MPI_Recv(recv_buf, count, datatype, source, tag, comm, status);
12001200

12011201
if (cxl_buf && ret == MPI_SUCCESS) {
1202-
memcpy(buf, cxl_buf, max_size);
1202+
cxl_safe_memcpy(buf, cxl_buf, max_size);
12031203
LOG_TRACE("MPI_Recv[%d]: copied %zu bytes from CXL\n", call_num, max_size);
12041204
}
12051205

@@ -1245,7 +1245,7 @@ int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int t
12451245
if (g_cxl.initialized && getenv("CXL_SHIM_COPY_SEND")) {
12461246
void *cxl_buf = allocate_cxl_memory(total_size);
12471247
if (cxl_buf) {
1248-
memcpy(cxl_buf, buf, total_size);
1248+
cxl_safe_memcpy(cxl_buf, buf, total_size);
12491249
send_buf = cxl_buf;
12501250
LOG_TRACE("MPI_Isend[%d]: copied %zu bytes to CXL at %p\n",
12511251
call_num, total_size, cxl_buf);
@@ -1393,10 +1393,18 @@ static cxl_window_t *register_cxl_window(MPI_Win win, void *base, size_t size,
13931393

13941394
// Allocate shared memory for window metadata if CXL is available
13951395
if (g_cxl.initialized && g_cxl.cxl_comm_enabled) {
1396-
cxl_win->shm = (cxl_win_shm_t *)allocate_cxl_memory(sizeof(cxl_win_shm_t));
1397-
if (cxl_win->shm) {
1398-
// Initialize on first rank
1399-
if (rank == 0) {
1396+
LOAD_ORIGINAL(MPI_Bcast);
1397+
LOAD_ORIGINAL(MPI_Barrier);
1398+
1399+
// Only rank 0 allocates the shared window metadata structure
1400+
// Then broadcast the rptr to all ranks so they share the same structure
1401+
cxl_rptr_t shm_rptr = CXL_RPTR_NULL;
1402+
1403+
if (rank == 0) {
1404+
cxl_win->shm = (cxl_win_shm_t *)allocate_cxl_memory(sizeof(cxl_win_shm_t));
1405+
if (cxl_win->shm) {
1406+
shm_rptr = ptr_to_rptr(cxl_win->shm);
1407+
// Initialize the shared structure
14001408
cxl_win->shm->magic = CXL_WIN_MAGIC;
14011409
cxl_win->shm->win_id = cxl_win->win_id;
14021410
atomic_store(&cxl_win->shm->ref_count, 0);
@@ -1405,37 +1413,56 @@ static cxl_window_t *register_cxl_window(MPI_Win win, void *base, size_t size,
14051413
cxl_win->shm->disp_unit = 1;
14061414
atomic_store(&cxl_win->shm->barrier_count, 0);
14071415
atomic_store(&cxl_win->shm->barrier_sense, 0);
1416+
// Initialize all rank info entries
1417+
for (int r = 0; r < comm_size && r < CXL_MAX_RANKS; r++) {
1418+
cxl_win->shm->ranks[r].base_rptr = CXL_RPTR_NULL;
1419+
cxl_win->shm->ranks[r].size = 0;
1420+
cxl_win->shm->ranks[r].owner_rank = r;
1421+
atomic_store(&cxl_win->shm->ranks[r].lock_count, 0);
1422+
atomic_store(&cxl_win->shm->ranks[r].exclusive_lock, (uint32_t)-1);
1423+
atomic_store(&cxl_win->shm->ranks[r].fence_counter, 0);
1424+
atomic_store(&cxl_win->shm->ranks[r].sync_state, CXL_WIN_UNLOCKED);
1425+
}
1426+
__atomic_thread_fence(__ATOMIC_SEQ_CST);
14081427
}
1428+
}
1429+
1430+
// Broadcast the shm rptr from rank 0 to all ranks
1431+
orig_MPI_Bcast(&shm_rptr, sizeof(shm_rptr), MPI_BYTE, 0, comm);
14091432

1433+
// All ranks convert rptr to local pointer
1434+
if (shm_rptr != CXL_RPTR_NULL) {
1435+
cxl_win->shm = (cxl_win_shm_t *)rptr_to_ptr(shm_rptr);
1436+
}
1437+
1438+
if (cxl_win->shm) {
1439+
// Now all ranks share the same shm structure
14101440
// Register this rank's window region
14111441
cxl_win_rank_info_t *rank_info = &cxl_win->shm->ranks[rank];
14121442

1413-
// If base is in CXL memory, use it directly; otherwise copy
1443+
// Only use CXL acceleration if base is already in CXL memory
1444+
// Copying non-CXL buffers would break MPI semantics since
1445+
// Put/Get/Accumulate would modify the copy, not the original
14141446
if (is_cxl_ptr(base)) {
14151447
rank_info->base_rptr = ptr_to_rptr(base);
1448+
rank_info->size = size;
1449+
__atomic_thread_fence(__ATOMIC_SEQ_CST);
1450+
atomic_fetch_add(&cxl_win->shm->ref_count, 1);
1451+
cxl_win->cxl_enabled = true;
14161452
} else {
1417-
// Allocate CXL memory for this window
1418-
void *cxl_base = allocate_cxl_memory(size);
1419-
if (cxl_base) {
1420-
cxl_safe_memcpy(cxl_base, base, size);
1421-
rank_info->base_rptr = ptr_to_rptr(cxl_base);
1422-
} else {
1423-
rank_info->base_rptr = CXL_RPTR_NULL;
1424-
}
1453+
// Non-CXL buffer - disable CXL acceleration for this window
1454+
rank_info->base_rptr = CXL_RPTR_NULL;
1455+
rank_info->size = 0;
1456+
cxl_win->cxl_enabled = false;
1457+
LOG_DEBUG("Window base %p is not in CXL memory, disabling CXL acceleration\n", base);
14251458
}
14261459

1427-
rank_info->size = size;
1428-
rank_info->owner_rank = rank;
1429-
atomic_store(&rank_info->lock_count, 0);
1430-
atomic_store(&rank_info->exclusive_lock, (uint32_t)-1);
1431-
atomic_store(&rank_info->fence_counter, 0);
1432-
atomic_store(&rank_info->sync_state, CXL_WIN_UNLOCKED);
1433-
1434-
atomic_fetch_add(&cxl_win->shm->ref_count, 1);
1435-
cxl_win->cxl_enabled = (rank_info->base_rptr != CXL_RPTR_NULL);
1460+
// Barrier to ensure all ranks have registered before proceeding
1461+
orig_MPI_Barrier(comm);
14361462

1437-
LOG_DEBUG("Registered CXL window %u for rank %d: base_rptr=0x%lx, size=%zu\n",
1438-
cxl_win->win_id, rank, rank_info->base_rptr, size);
1463+
LOG_DEBUG("Registered CXL window %u for rank %d: shm_rptr=0x%lx, base_rptr=0x%lx, size=%zu\n",
1464+
cxl_win->win_id, rank, (unsigned long)shm_rptr,
1465+
(unsigned long)rank_info->base_rptr, size);
14391466
}
14401467
}
14411468

@@ -1740,26 +1767,22 @@ int MPI_Win_fence(int assert, MPI_Win win) {
17401767
cxl_window_t *cxl_win = find_cxl_window(win);
17411768

17421769
if (cxl_win && cxl_win->cxl_enabled && cxl_win->shm) {
1743-
// Full memory barrier
1770+
// Full memory barrier to ensure all CXL writes are visible
17441771
__atomic_thread_fence(__ATOMIC_SEQ_CST);
17451772

1746-
// Increment local fence counter
1773+
// Track fence epochs for debugging
17471774
cxl_win_rank_info_t *my_info = &cxl_win->shm->ranks[cxl_win->my_rank];
17481775
uint64_t my_fence = atomic_fetch_add(&my_info->fence_counter, 1) + 1;
1749-
1750-
// Increment global fence
17511776
atomic_fetch_add(&cxl_win->shm->global_fence, 1);
17521777

1753-
// Wait for all ranks to reach this fence (barrier)
1754-
cxl_barrier(cxl_win->shm, cxl_win->my_rank, cxl_win->comm_size);
1778+
// Use MPI fence for synchronization - it will barrier internally
1779+
int ret = orig_MPI_Win_fence(assert, win);
17551780

17561781
// Another memory barrier after synchronization
17571782
__atomic_thread_fence(__ATOMIC_SEQ_CST);
17581783

17591784
LOG_DEBUG("MPI_Win_fence[%d]: CXL fence completed (epoch=%lu)\n", call_num, my_fence);
1760-
1761-
// Still call original for compatibility
1762-
return orig_MPI_Win_fence(assert, win);
1785+
return ret;
17631786
}
17641787

17651788
return orig_MPI_Win_fence(assert, win);
@@ -1936,7 +1959,9 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype da
19361959
LOAD_ORIGINAL(MPI_Allreduce);
19371960

19381961
// For small allreduce on COMM_WORLD with SUM, try CXL optimization
1962+
// Note: Skip if sendbuf is MPI_IN_PLACE to avoid complexity
19391963
if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD &&
1964+
sendbuf != MPI_IN_PLACE &&
19401965
op == MPI_SUM && total_size <= 4096 && g_cxl.world_size <= 64 &&
19411966
(datatype == MPI_DOUBLE || datatype == MPI_FLOAT || datatype == MPI_INT ||
19421967
datatype == MPI_LONG || datatype == MPI_LONG_LONG)) {
@@ -2031,8 +2056,10 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
20312056
LOAD_ORIGINAL(MPI_Allgather);
20322057

20332058
// CXL-optimized allgather for COMM_WORLD
2034-
if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && send_bytes <= 4096 &&
2035-
g_cxl.world_size <= 64) {
2059+
// Skip MPI_IN_PLACE to avoid complexity
2060+
if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD &&
2061+
sendbuf != MPI_IN_PLACE &&
2062+
send_bytes <= 4096 && g_cxl.world_size <= 64) {
20362063

20372064
LOAD_ORIGINAL(MPI_Barrier);
20382065

0 commit comments

Comments
 (0)