
Commit 78cd17d

update

1 parent a80ba07 commit 78cd17d

1 file changed: +84 -79 lines

workloads/gromacs/mpi_cxl_shim.c
Lines changed: 84 additions & 79 deletions
@@ -40,6 +40,26 @@ typedef uint64_t cxl_rptr_t;
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #endif
 
+// Safe memory copy for CXL memory - avoids AVX-512/SIMD instructions that may crash on CXL
+// Uses volatile to prevent compiler from optimizing into SIMD memcpy
+static inline void cxl_safe_memcpy(void *dst, const void *src, size_t n) {
+    volatile unsigned char *d = (volatile unsigned char *)dst;
+    const volatile unsigned char *s = (const volatile unsigned char *)src;
+
+    // Copy 8 bytes at a time for better performance while staying safe
+    while (n >= 8) {
+        *(volatile uint64_t *)d = *(const volatile uint64_t *)s;
+        d += 8;
+        s += 8;
+        n -= 8;
+    }
+    // Copy remaining bytes
+    while (n > 0) {
+        *d++ = *s++;
+        n--;
+    }
+}
+
 // Add color output for better visibility
 #define RED "\x1b[31m"
 #define GREEN "\x1b[32m"
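The new helper relies on volatile accesses, which the compiler must perform exactly as written, so GCC/Clang cannot fuse the loop into a vectorized library memcpy. A minimal standalone check of the helper (harness and test sizes are illustrative, not part of the commit; cxl_safe_memcpy must be pasted in from the shim):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    /* The size 37 deliberately exercises both the 8-byte loop
       (4 iterations) and the 5-byte byte-wise tail. */
    int main(void) {
        unsigned char src[37], dst[37];
        for (size_t i = 0; i < sizeof src; i++) src[i] = (unsigned char)i;
        memset(dst, 0xAA, sizeof dst);

        cxl_safe_memcpy(dst, src, sizeof src);

        puts(memcmp(dst, src, sizeof src) == 0 ? "ok" : "MISMATCH");
        return 0;
    }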
@@ -678,68 +698,46 @@ static cxl_rptr_t allocate_cxl_rptr(size_t size) {
 // CXL Collective Synchronization
 // ============================================================================
 
-// Flush a cache line to ensure visibility on DAX devices
-static inline void cxl_clflush(const void *addr) {
-    __asm__ volatile("clflush (%0)" :: "r"(addr) : "memory");
-}
-
-// Flush a range of memory
+// Memory fence to ensure visibility - MPI_Barrier handles cross-node sync
 static inline void cxl_flush_range(const void *addr, size_t size) {
-    const char *p = (const char *)((uintptr_t)addr & ~(CACHELINE_SIZE - 1));
-    const char *end = (const char *)addr + size;
-    while (p < end) {
-        cxl_clflush(p);
-        p += CACHELINE_SIZE;
-    }
-    __asm__ volatile("sfence" ::: "memory");
+    (void)addr;
+    (void)size;
+    // Use memory fence instead of clflush - MPI_Barrier ensures cross-node visibility
+    __atomic_thread_fence(__ATOMIC_SEQ_CST);
 }
 
-// Sense-reversing barrier using CXL shared memory
+// Sense-reversing barrier using CXL shared memory (not used - MPI_Barrier preferred)
 // Returns the phase number after the barrier (for collective data coordination)
 static uint32_t cxl_collective_barrier(int num_ranks) {
     if (!g_cxl.header || num_ranks <= 0) return 0;
 
-    // Memory fence and flush before entering barrier
     __atomic_thread_fence(__ATOMIC_SEQ_CST);
-    cxl_clflush(&g_cxl.header->coll_barrier_count);
-    cxl_clflush(&g_cxl.header->coll_barrier_sense);
-    __asm__ volatile("sfence" ::: "memory");
 
     // Get current sense value
     uint32_t my_sense = atomic_load(&g_cxl.header->coll_barrier_sense);
 
     // Increment barrier count
     uint32_t count = atomic_fetch_add(&g_cxl.header->coll_barrier_count, 1) + 1;
-    cxl_clflush(&g_cxl.header->coll_barrier_count);
-    __asm__ volatile("sfence" ::: "memory");
 
     LOG_TRACE("Barrier: rank=%d, count=%u/%d, sense=%u\n",
               g_cxl.my_rank, count, num_ranks, my_sense);
 
     if (count == (uint32_t)num_ranks) {
         // Last one in - reset count and flip sense
         atomic_store(&g_cxl.header->coll_barrier_count, 0);
-        // Increment phase for collective coordination
         uint32_t new_phase = atomic_fetch_add(&g_cxl.header->coll_phase, 1) + 1;
-        // Flip sense to release all waiters
         __atomic_thread_fence(__ATOMIC_SEQ_CST);
         atomic_store(&g_cxl.header->coll_barrier_sense, 1 - my_sense);
-        // Flush to ensure other nodes see the update
-        cxl_clflush(&g_cxl.header->coll_barrier_count);
-        cxl_clflush(&g_cxl.header->coll_barrier_sense);
-        cxl_clflush(&g_cxl.header->coll_phase);
-        __asm__ volatile("sfence" ::: "memory");
+        __atomic_thread_fence(__ATOMIC_SEQ_CST);
         LOG_TRACE("Barrier: rank=%d is LAST, new_phase=%u\n", g_cxl.my_rank, new_phase);
         return new_phase;
     } else {
         // Wait for sense to flip
         int spin_count = 0;
         while (atomic_load(&g_cxl.header->coll_barrier_sense) == my_sense) {
             __asm__ volatile("pause" ::: "memory");
-            if (++spin_count % 1000000 == 0) {
-                // Periodically re-flush to pick up updates from other nodes
-                cxl_clflush(&g_cxl.header->coll_barrier_sense);
-                __asm__ volatile("lfence" ::: "memory");
+            if (++spin_count % 10000000 == 0) {
+                LOG_TRACE("Barrier WAIT: rank=%d, sense=%u\n", g_cxl.my_rank, my_sense);
             }
         }
         __atomic_thread_fence(__ATOMIC_SEQ_CST);
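The barrier keeps the classic sense-reversing structure even though the commit now routes synchronization through the real MPI_Barrier. For reference, the same pattern reduced to a single-process, multi-threaded sketch with C11 atomics (harness and names are illustrative, not from the shim; build with -pthread):

    #include <pthread.h>
    #include <stdatomic.h>
    #include <stdio.h>

    #define NUM_THREADS 4

    static atomic_uint barrier_count;
    static atomic_uint barrier_sense;

    /* Same idea as cxl_collective_barrier, but between threads: the last
       arrival resets the count (before the flip, so re-entry is safe) and
       flips the shared sense; everyone else spins until the sense no
       longer matches the value captured on entry. */
    static void sense_barrier(int num) {
        unsigned my_sense = atomic_load(&barrier_sense);
        if (atomic_fetch_add(&barrier_count, 1) + 1 == (unsigned)num) {
            atomic_store(&barrier_count, 0);
            atomic_store(&barrier_sense, 1 - my_sense);  /* release waiters */
        } else {
            while (atomic_load(&barrier_sense) == my_sense)
                ;  /* spin */
        }
    }

    static void *worker(void *arg) {
        long id = (long)arg;
        for (int round = 0; round < 3; round++) {
            printf("thread %ld at round %d\n", id, round);
            sense_barrier(NUM_THREADS);
        }
        return NULL;
    }

    int main(void) {
        pthread_t t[NUM_THREADS];
        for (long i = 0; i < NUM_THREADS; i++)
            pthread_create(&t[i], NULL, worker, (void *)i);
        for (int i = 0; i < NUM_THREADS; i++)
            pthread_join(t[i], NULL);
        return 0;
    }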
@@ -753,20 +751,14 @@ static void cxl_collective_register_buffer(int rank, void *buf) {
     if (!g_cxl.header || rank < 0 || rank >= CXL_MAX_RANKS) return;
     cxl_rptr_t rptr = ptr_to_rptr(buf);
     g_cxl.header->coll_data_rptr[rank] = rptr;
-    __atomic_thread_fence(__ATOMIC_RELEASE);
-    // Flush to ensure other nodes see this registration
-    cxl_clflush(&g_cxl.header->coll_data_rptr[rank]);
-    __asm__ volatile("sfence" ::: "memory");
+    __atomic_thread_fence(__ATOMIC_SEQ_CST);
     LOG_TRACE("Registered buffer: rank=%d, buf=%p, rptr=0x%lx\n", rank, buf, (unsigned long)rptr);
 }
 
 // Get another rank's buffer for a collective operation
 static void *cxl_collective_get_buffer(int rank) {
     if (!g_cxl.header || rank < 0 || rank >= CXL_MAX_RANKS) return NULL;
-    // Flush to ensure we see updates from other nodes
-    cxl_clflush(&g_cxl.header->coll_data_rptr[rank]);
-    __asm__ volatile("lfence" ::: "memory");
-    __atomic_thread_fence(__ATOMIC_ACQUIRE);
+    __atomic_thread_fence(__ATOMIC_SEQ_CST);
     cxl_rptr_t rptr = g_cxl.header->coll_data_rptr[rank];
     void *ptr = rptr_to_ptr(rptr);
     LOG_TRACE("Get buffer: rank=%d, rptr=0x%lx, ptr=%p\n", rank, (unsigned long)rptr, ptr);
@@ -1857,13 +1849,8 @@ int MPI_Barrier(MPI_Comm comm) {
 
     LOAD_ORIGINAL(MPI_Barrier);
 
-    // For CXL-enabled runs with COMM_WORLD, use optimized barrier
-    if (g_cxl.cxl_comm_enabled && g_cxl.header && comm == MPI_COMM_WORLD &&
-        g_cxl.world_size <= 64) {
-        cxl_collective_barrier(g_cxl.world_size);
-        LOG_DEBUG("MPI_Barrier[%d]: CXL optimized\n", call_num);
-        return MPI_SUCCESS;
-    }
+    // Note: CXL barrier disabled due to cache coherency issues across nodes
+    // Always use original MPI_Barrier for reliable synchronization
 
     return orig_MPI_Barrier(comm);
 }
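LOAD_ORIGINAL and the orig_MPI_* pointers belong to the shim's interposition machinery, which this diff does not show. In LD_PRELOAD shims this is conventionally done by resolving the next definition of the symbol with dlsym(RTLD_NEXT, ...); a hedged sketch of what such a macro typically looks like (the shim's real macro may differ; typeof is a GCC/Clang extension):

    #define _GNU_SOURCE
    #include <dlfcn.h>

    /* Illustrative only: resolve and cache the real MPI symbol on first
       use, so the wrapper can forward to the implementation it shadows. */
    #define LOAD_ORIGINAL_SKETCH(name)                              \
        static typeof(name) *orig_##name = NULL;                    \
        if (!orig_##name)                                           \
            orig_##name = (typeof(name) *)dlsym(RTLD_NEXT, #name)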
@@ -1884,20 +1871,23 @@ int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm
     if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && total_size <= 4096 &&
         g_cxl.world_size <= 64) {
 
+        LOAD_ORIGINAL(MPI_Barrier);
+
         // Root allocates buffer and writes data
         void *root_buf = NULL;
         if (g_cxl.my_rank == root) {
             root_buf = allocate_cxl_memory(total_size);
             if (root_buf) {
-                memcpy(root_buf, buffer, total_size);
+                cxl_safe_memcpy(root_buf, buffer, total_size);
+                cxl_flush_range(root_buf, total_size);
             }
         }
 
         // Root registers its buffer (or NULL for non-root)
         cxl_collective_register_buffer(g_cxl.my_rank, root_buf);
 
-        // Barrier to ensure root has registered
-        cxl_collective_barrier(g_cxl.world_size);
+        // Use MPI_Barrier for reliable synchronization across nodes
+        orig_MPI_Barrier(comm);
 
         // All non-root ranks get root's buffer address and read
         if (g_cxl.my_rank != root) {
@@ -1906,11 +1896,11 @@ int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm
                 LOG_WARN("MPI_Bcast[%d]: Root buffer unavailable, fallback\n", call_num);
                 goto bcast_fallback;
             }
-            memcpy(buffer, bcast_buf, total_size);
+            cxl_safe_memcpy(buffer, bcast_buf, total_size);
         }
 
         // Final barrier before returning
-        cxl_collective_barrier(g_cxl.world_size);
+        orig_MPI_Barrier(comm);
 
         LOG_DEBUG("MPI_Bcast[%d]: CXL optimized (%zu bytes from root %d)\n",
                   call_num, total_size, root);
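Taken together, the two Bcast hunks leave the small-message path with the following shape (a comment-form summary of the code above, not new code):

    /* root:      buf = allocate_cxl_memory(n);
                  cxl_safe_memcpy(buf, user_buf, n);    scalar stores only
                  cxl_flush_range(buf, n);              now just a SEQ_CST fence
                  cxl_collective_register_buffer(root, buf);
       everyone:  orig_MPI_Barrier(comm);               real MPI barrier
       non-root:  src = cxl_collective_get_buffer(root);
                  cxl_safe_memcpy(user_buf, src, n);
       everyone:  orig_MPI_Barrier(comm);               before buffer reuse */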
@@ -1951,17 +1941,20 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype da
         (datatype == MPI_DOUBLE || datatype == MPI_FLOAT || datatype == MPI_INT ||
          datatype == MPI_LONG || datatype == MPI_LONG_LONG)) {
 
+        LOAD_ORIGINAL(MPI_Barrier);
+
         // Allocate per-rank buffer in shared memory
         void *my_buf = allocate_cxl_memory(total_size);
         if (my_buf) {
-            // Copy my data to shared buffer
-            memcpy(my_buf, sendbuf, total_size);
+            // Copy my data to shared buffer (use safe copy for CXL)
+            cxl_safe_memcpy(my_buf, sendbuf, total_size);
+            cxl_flush_range(my_buf, total_size);
 
             // Register my buffer location so other ranks can find it
             cxl_collective_register_buffer(g_cxl.my_rank, my_buf);
 
-            // Barrier to ensure all ranks have registered and data is visible
-            cxl_collective_barrier(g_cxl.world_size);
+            // Use MPI_Barrier for reliable synchronization across nodes
+            orig_MPI_Barrier(comm);
 
             // Initialize result with my own data
             memcpy(recvbuf, sendbuf, total_size);
@@ -2013,7 +2006,7 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype da
             }
 
             // Final barrier before returning (ensure all ranks finished reading)
-            cxl_collective_barrier(g_cxl.world_size);
+            orig_MPI_Barrier(comm);
 
             LOG_DEBUG("MPI_Allreduce[%d]: CXL optimized SUM (%zu bytes)\n", call_num, total_size);
             return MPI_SUCCESS;
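The accumulation loop between these two Allreduce hunks is unchanged by the commit and therefore not shown. For orientation only, an elementwise MPI_SUM over the registered peer buffers in the MPI_DOUBLE case would look roughly like this (a sketch of the implied logic; the fallback label and loop structure are hypothetical, not the shim's exact code):

    for (int r = 0; r < g_cxl.world_size; r++) {
        if (r == g_cxl.my_rank) continue;        /* recvbuf already seeded */
        const double *their = (const double *)cxl_collective_get_buffer(r);
        if (!their) goto allreduce_fallback;     /* hypothetical label */
        double *acc = (double *)recvbuf;
        for (int i = 0; i < count; i++)
            acc[i] += their[i];                  /* MPI_SUM, MPI_DOUBLE */
    }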
@@ -2041,17 +2034,20 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
     if (g_cxl.cxl_comm_enabled && comm == MPI_COMM_WORLD && send_bytes <= 4096 &&
         g_cxl.world_size <= 64) {
 
+        LOAD_ORIGINAL(MPI_Barrier);
+
         // Each rank allocates its own contribution buffer
         void *my_buf = allocate_cxl_memory(send_bytes);
         if (my_buf) {
-            // Copy my contribution to shared memory
-            memcpy(my_buf, sendbuf, send_bytes);
+            // Copy my contribution to shared memory (use safe copy for CXL)
+            cxl_safe_memcpy(my_buf, sendbuf, send_bytes);
+            cxl_flush_range(my_buf, send_bytes);
 
             // Register my buffer location
             cxl_collective_register_buffer(g_cxl.my_rank, my_buf);
 
-            // Barrier to ensure all ranks have registered and data is visible
-            cxl_collective_barrier(g_cxl.world_size);
+            // Use MPI_Barrier for reliable synchronization across nodes
+            orig_MPI_Barrier(comm);
 
             // Each rank reads all contributions in order
             for (int r = 0; r < g_cxl.world_size; r++) {
@@ -2062,11 +2058,11 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                     goto allgather_fallback;
                 }
                 void *dst = (char *)recvbuf + r * send_bytes;
-                memcpy(dst, their_buf, send_bytes);
+                cxl_safe_memcpy(dst, their_buf, send_bytes);
             }
 
             // Final barrier before returning
-            cxl_collective_barrier(g_cxl.world_size);
+            orig_MPI_Barrier(comm);
 
             LOG_DEBUG("MPI_Allgather[%d]: CXL optimized (%zu bytes each)\n", call_num, send_bytes);
             return MPI_SUCCESS;
@@ -2099,18 +2095,21 @@ int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
     int n = g_cxl.world_size;
     size_t row_size = send_bytes * n; // Total data this rank sends
 
+    LOAD_ORIGINAL(MPI_Barrier);
+
     // Each rank allocates buffer for its outgoing row
     void *my_row = allocate_cxl_memory(row_size);
 
     if (my_row) {
-        // Copy my send data to shared memory
-        memcpy(my_row, sendbuf, row_size);
+        // Copy my send data to shared memory (use safe copy for CXL)
+        cxl_safe_memcpy(my_row, sendbuf, row_size);
+        cxl_flush_range(my_row, row_size);
 
         // Register my row buffer
         cxl_collective_register_buffer(g_cxl.my_rank, my_row);
 
-        // Barrier to ensure all ranks have registered and data is visible
-        cxl_collective_barrier(n);
+        // Use MPI_Barrier for reliable synchronization across nodes
+        orig_MPI_Barrier(comm);
 
         // Read my column - element [r][my_rank] from each rank r
         for (int r = 0; r < n; r++) {
@@ -2123,11 +2122,11 @@ int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
             // Element [r][my_rank] = data rank r sends to me
             void *src = (char *)their_row + g_cxl.my_rank * send_bytes;
             void *dst = (char *)recvbuf + r * recv_bytes;
-            memcpy(dst, src, recv_bytes);
+            cxl_safe_memcpy(dst, src, recv_bytes);
         }
 
         // Final barrier before returning
-        cxl_collective_barrier(n);
+        orig_MPI_Barrier(comm);
 
         int cxl_num = atomic_fetch_add(&cxl_alltoall_count, 1);
         LOG_DEBUG("MPI_Alltoall[%d]: CXL direct #%d (%zu bytes per rank)\n",
@@ -2157,18 +2156,21 @@ int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
     size_t send_bytes = (size_t)sendcount * send_size;
 
     if (send_bytes <= 4096) {
+        LOAD_ORIGINAL(MPI_Barrier);
+
         // Each rank allocates its own contribution buffer
         void *my_buf = allocate_cxl_memory(send_bytes);
 
         if (my_buf) {
-            // Copy my contribution to shared memory
-            memcpy(my_buf, sendbuf, send_bytes);
+            // Copy my contribution to shared memory (use safe copy for CXL)
+            cxl_safe_memcpy(my_buf, sendbuf, send_bytes);
+            cxl_flush_range(my_buf, send_bytes);
 
             // Register my buffer location
             cxl_collective_register_buffer(g_cxl.my_rank, my_buf);
 
-            // Barrier to ensure all ranks have registered
-            cxl_collective_barrier(g_cxl.world_size);
+            // Use MPI_Barrier for reliable synchronization across nodes
+            orig_MPI_Barrier(comm);
 
             // Root reads all contributions
             if (g_cxl.my_rank == root) {
@@ -2180,12 +2182,12 @@ int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                         goto gather_fallback;
                     }
                     void *dst = (char *)recvbuf + r * send_bytes;
-                    memcpy(dst, their_buf, send_bytes);
+                    cxl_safe_memcpy(dst, their_buf, send_bytes);
                 }
             }
 
             // Final barrier before returning
-            cxl_collective_barrier(g_cxl.world_size);
+            orig_MPI_Barrier(comm);
 
             LOG_DEBUG("MPI_Gather[%d]: CXL optimized\n", call_num);
             return MPI_SUCCESS;
@@ -2216,20 +2218,23 @@ int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
     if (send_bytes <= 4096) {
         size_t total_size = send_bytes * g_cxl.world_size;
 
+        LOAD_ORIGINAL(MPI_Barrier);
+
         // Root allocates buffer and writes all data
         void *root_buf = NULL;
         if (g_cxl.my_rank == root) {
             root_buf = allocate_cxl_memory(total_size);
             if (root_buf) {
-                memcpy(root_buf, sendbuf, total_size);
+                cxl_safe_memcpy(root_buf, sendbuf, total_size);
+                cxl_flush_range(root_buf, total_size);
             }
         }
 
         // Root registers its buffer (or NULL for non-root)
         cxl_collective_register_buffer(g_cxl.my_rank, root_buf);
 
-        // Barrier to ensure root has registered
-        cxl_collective_barrier(g_cxl.world_size);
+        // Use MPI_Barrier for reliable synchronization across nodes
+        orig_MPI_Barrier(comm);
 
         // All ranks get root's buffer address and read their portion
         void *scatter_buf = cxl_collective_get_buffer(root);
@@ -2239,10 +2244,10 @@ int MPI_Scatter(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
         }
 
         void *my_slot = (char *)scatter_buf + g_cxl.my_rank * send_bytes;
-        memcpy(recvbuf, my_slot, send_bytes);
+        cxl_safe_memcpy(recvbuf, my_slot, send_bytes);
 
         // Final barrier before returning
-        cxl_collective_barrier(g_cxl.world_size);
+        orig_MPI_Barrier(comm);
 
         LOG_DEBUG("MPI_Scatter[%d]: CXL optimized\n", call_num);
         return MPI_SUCCESS;
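For completeness, a shim like this is normally built as a shared object and injected at launch; the flags and paths below are illustrative rather than taken from this repository:

    mpicc -shared -fPIC -O2 mpi_cxl_shim.c -o libmpi_cxl_shim.so -ldl
    LD_PRELOAD=./libmpi_cxl_shim.so mpirun -np 8 gmx_mpi mdrun ...

Building with the MPI compiler wrapper (mpicc) avoids locating mpi.h by hand, and -ldl covers the dlsym-based symbol resolution the interposition relies on.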
