Skip to content

Commit c0814e4

Browse files
committed
update
1 parent 2dda850 commit c0814e4

File tree

2 files changed

+161
-112
lines changed

2 files changed

+161
-112
lines changed

scripts/start_cxlmemsim_with_qemu.sh

Lines changed: 0 additions & 39 deletions
This file was deleted.

workloads/gromacs/mpi_cxl_shim.c

Lines changed: 161 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
#include <unistd.h>
1616
#include <execinfo.h>
1717
#include <signal.h>
18+
#include <stdarg.h>
1819

1920
#define CACHELINE_SIZE 64
2021
#define DEFAULT_CXL_SIZE (4UL * 1024 * 1024 * 1024) // 4GB default
2122
#define CXL_ALIGNMENT 4096
2223
#define SHIM_VERSION "2.0"
2324

25+
#ifndef MIN
26+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
27+
#endif
28+
2429
// Add color output for better visibility
2530
#define RED "\x1b[31m"
2631
#define GREEN "\x1b[32m"
@@ -109,90 +114,145 @@ static void signal_handler(int sig) {
109114
exit(1);
110115
}
111116

117+
// Get DAX device size from sysfs
118+
static size_t get_dax_size(const char *dax_path) {
119+
char sysfs_path[512];
120+
const char *dev_name = strrchr(dax_path, '/');
121+
if (!dev_name) dev_name = dax_path;
122+
else dev_name++;
123+
124+
snprintf(sysfs_path, sizeof(sysfs_path), "/sys/bus/dax/devices/%s/size", dev_name);
125+
126+
FILE *f = fopen(sysfs_path, "r");
127+
if (!f) {
128+
LOG_WARN("Cannot read DAX size from %s, using stat\n", sysfs_path);
129+
return 0;
130+
}
131+
132+
unsigned long long size = 0;
133+
if (fscanf(f, "%llu", &size) != 1) {
134+
fclose(f);
135+
return 0;
136+
}
137+
fclose(f);
138+
139+
LOG_DEBUG("DAX device %s size from sysfs: %llu bytes\n", dev_name, size);
140+
return (size_t)size;
141+
}
142+
112143
// Initialize CXL memory
113144
static void init_cxl_memory(void) {
114145
if (g_cxl.initialized) return;
115-
146+
116147
pthread_mutex_lock(&g_cxl.lock);
117148
if (g_cxl.initialized) {
118149
pthread_mutex_unlock(&g_cxl.lock);
119150
return;
120151
}
121-
152+
122153
const char *dax_path = getenv("CXL_DAX_PATH");
123154
const char *cxl_size_str = getenv("CXL_MEM_SIZE");
124155
size_t cxl_size = cxl_size_str ? strtoull(cxl_size_str, NULL, 0) : DEFAULT_CXL_SIZE;
125-
126-
if (dax_path && access(dax_path, R_OK | W_OK) == 0) {
127-
// Use DAX device
156+
157+
if (dax_path && strlen(dax_path) > 0) {
158+
// Use DAX device - open with O_RDWR for shared access
128159
g_cxl.fd = open(dax_path, O_RDWR);
129160
if (g_cxl.fd < 0) {
130161
LOG_ERROR("Failed to open DAX device %s: %s\n", dax_path, strerror(errno));
131162
goto use_shm;
132163
}
133-
134-
// Get actual DAX size
135-
struct stat st;
136-
if (fstat(g_cxl.fd, &st) == 0) {
137-
cxl_size = st.st_size;
164+
165+
// Try to get DAX size from sysfs first
166+
cxl_size = get_dax_size(dax_path);
167+
if (cxl_size == 0) {
168+
// Fallback to stat
169+
struct stat st;
170+
if (fstat(g_cxl.fd, &st) == 0) {
171+
cxl_size = st.st_size;
172+
}
173+
if (cxl_size == 0) {
174+
// Use default if still 0
175+
cxl_size = DEFAULT_CXL_SIZE;
176+
}
138177
}
139-
178+
140179
g_cxl.base = mmap(NULL, cxl_size, PROT_READ | PROT_WRITE, MAP_SHARED, g_cxl.fd, 0);
141180
if (g_cxl.base == MAP_FAILED) {
142-
LOG_ERROR("Failed to map DAX device: %s\n", strerror(errno));
181+
LOG_ERROR("Failed to map DAX device %s: %s\n", dax_path, strerror(errno));
143182
close(g_cxl.fd);
144183
goto use_shm;
145184
}
146-
185+
147186
g_cxl.type = "dax";
148-
LOG_INFO("Mapped DAX device %s: %zu bytes at %p\n", dax_path, cxl_size, g_cxl.base);
187+
LOG_INFO("Mapped DAX device %s: %zu bytes (%zu MB) at %p\n",
188+
dax_path, cxl_size, cxl_size / (1024*1024), g_cxl.base);
189+
190+
// For DAX, we need to coordinate allocation between processes
191+
// Use first cacheline as allocation counter
192+
if (getenv("CXL_DAX_RESET")) {
193+
// Only reset if explicitly requested
194+
memset(g_cxl.base, 0, CACHELINE_SIZE);
195+
LOG_INFO("Reset DAX allocation counter\n");
196+
}
197+
198+
// DAX allocation starts after first cacheline
199+
g_cxl.used = CACHELINE_SIZE;
200+
149201
} else {
150202
use_shm:
151-
// Use shared memory fallback
152-
char shm_name[256];
153-
snprintf(shm_name, sizeof(shm_name), "/cxlmemsim_mpi_%d", getuid());
154-
155-
// Try to unlink first in case it exists
156-
shm_unlink(shm_name);
157-
158-
g_cxl.fd = shm_open(shm_name, O_CREAT | O_RDWR | O_EXCL, 0600);
203+
// Use shared memory fallback - create a single shared segment
204+
const char *shm_name = "/cxlmemsim_mpi_shared";
205+
206+
// Try to open existing first
207+
g_cxl.fd = shm_open(shm_name, O_RDWR, 0600);
208+
159209
if (g_cxl.fd < 0) {
160-
LOG_ERROR("Failed to create shared memory %s: %s\n", shm_name, strerror(errno));
161-
pthread_mutex_unlock(&g_cxl.lock);
162-
return;
163-
}
164-
165-
if (ftruncate(g_cxl.fd, cxl_size) != 0) {
166-
LOG_ERROR("Failed to resize shared memory: %s\n", strerror(errno));
167-
close(g_cxl.fd);
168-
shm_unlink(shm_name);
169-
pthread_mutex_unlock(&g_cxl.lock);
170-
return;
210+
// Create new
211+
g_cxl.fd = shm_open(shm_name, O_CREAT | O_RDWR, 0666);
212+
if (g_cxl.fd < 0) {
213+
LOG_ERROR("Failed to create/open shared memory %s: %s\n", shm_name, strerror(errno));
214+
pthread_mutex_unlock(&g_cxl.lock);
215+
return;
216+
}
217+
218+
if (ftruncate(g_cxl.fd, cxl_size) != 0) {
219+
LOG_ERROR("Failed to resize shared memory: %s\n", strerror(errno));
220+
close(g_cxl.fd);
221+
pthread_mutex_unlock(&g_cxl.lock);
222+
return;
223+
}
224+
LOG_INFO("Created new shared memory segment %s\n", shm_name);
225+
} else {
226+
// Get existing size
227+
struct stat st;
228+
if (fstat(g_cxl.fd, &st) == 0) {
229+
cxl_size = st.st_size;
230+
}
231+
LOG_INFO("Opened existing shared memory segment %s\n", shm_name);
171232
}
172-
233+
173234
g_cxl.base = mmap(NULL, cxl_size, PROT_READ | PROT_WRITE, MAP_SHARED, g_cxl.fd, 0);
174235
if (g_cxl.base == MAP_FAILED) {
175236
LOG_ERROR("Failed to map shared memory: %s\n", strerror(errno));
176237
close(g_cxl.fd);
177-
shm_unlink(shm_name);
178238
pthread_mutex_unlock(&g_cxl.lock);
179239
return;
180240
}
181-
241+
182242
g_cxl.type = "shm";
183-
LOG_INFO("Created shared memory %s: %zu bytes at %p\n", shm_name, cxl_size, g_cxl.base);
243+
LOG_INFO("Mapped shared memory %s: %zu bytes (%zu MB) at %p\n",
244+
shm_name, cxl_size, cxl_size / (1024*1024), g_cxl.base);
245+
246+
// For shared memory, also use first cacheline for coordination
247+
g_cxl.used = CACHELINE_SIZE;
184248
}
185-
249+
186250
g_cxl.size = cxl_size;
187-
g_cxl.used = 0;
188251
g_cxl.initialized = true;
189-
190-
// Clear the memory
191-
memset(g_cxl.base, 0, MIN(4096, cxl_size));
192-
193-
LOG_INFO("CXL memory initialized: type=%s, size=%zu MB, base=%p\n",
252+
253+
LOG_INFO("CXL memory initialized: type=%s, size=%zu MB, base=%p\n",
194254
g_cxl.type, cxl_size / (1024*1024), g_cxl.base);
195-
255+
196256
pthread_mutex_unlock(&g_cxl.lock);
197257
}
198258

@@ -202,28 +262,54 @@ static void *allocate_cxl_memory(size_t size) {
202262
init_cxl_memory();
203263
if (!g_cxl.initialized) return NULL;
204264
}
205-
265+
206266
// Align size
207267
size = (size + CXL_ALIGNMENT - 1) & ~(CXL_ALIGNMENT - 1);
208-
209-
pthread_mutex_lock(&g_cxl.lock);
210-
211-
if (g_cxl.used + size > g_cxl.size) {
212-
LOG_WARN("Out of CXL memory: requested=%zu, available=%zu\n",
213-
size, g_cxl.size - g_cxl.used);
268+
269+
// For DAX/shared memory, use atomic operations on the allocation counter
270+
if (strcmp(g_cxl.type, "dax") == 0 || strcmp(g_cxl.type, "shm") == 0) {
271+
// Use first 8 bytes of the shared region as atomic allocation counter
272+
_Atomic size_t *alloc_counter = (_Atomic size_t *)g_cxl.base;
273+
274+
size_t old_used = atomic_fetch_add(alloc_counter, size);
275+
size_t new_used = old_used + size;
276+
277+
// Check if we have space (accounting for the counter itself)
278+
if (new_used > g_cxl.size) {
279+
// Roll back
280+
atomic_fetch_sub(alloc_counter, size);
281+
LOG_WARN("Out of CXL memory: requested=%zu, available=%zu\n",
282+
size, g_cxl.size - old_used);
283+
return NULL;
284+
}
285+
286+
void *ptr = (char *)g_cxl.base + old_used;
287+
288+
LOG_TRACE("Allocated %zu bytes at offset %zu (total used: %zu/%zu) [atomic]\n",
289+
size, old_used, new_used, g_cxl.size);
290+
291+
return ptr;
292+
} else {
293+
// Local allocation (shouldn't happen but keep for safety)
294+
pthread_mutex_lock(&g_cxl.lock);
295+
296+
if (g_cxl.used + size > g_cxl.size) {
297+
LOG_WARN("Out of CXL memory: requested=%zu, available=%zu\n",
298+
size, g_cxl.size - g_cxl.used);
299+
pthread_mutex_unlock(&g_cxl.lock);
300+
return NULL;
301+
}
302+
303+
void *ptr = (char *)g_cxl.base + g_cxl.used;
304+
g_cxl.used += size;
305+
306+
LOG_TRACE("Allocated %zu bytes at offset %zu (total used: %zu/%zu)\n",
307+
size, (size_t)((char *)ptr - (char *)g_cxl.base), g_cxl.used, g_cxl.size);
308+
214309
pthread_mutex_unlock(&g_cxl.lock);
215-
return NULL;
310+
311+
return ptr;
216312
}
217-
218-
void *ptr = (char *)g_cxl.base + g_cxl.used;
219-
g_cxl.used += size;
220-
221-
LOG_TRACE("Allocated %zu bytes at offset %zu (total used: %zu/%zu)\n",
222-
size, (size_t)((char *)ptr - (char *)g_cxl.base), g_cxl.used, g_cxl.size);
223-
224-
pthread_mutex_unlock(&g_cxl.lock);
225-
226-
return ptr;
227313
}
228314

229315
// Mapping management
@@ -304,28 +390,30 @@ int MPI_Init(int *argc, char ***argv) {
304390

305391
int MPI_Finalize(void) {
306392
LOG_INFO("=== MPI_Finalize HOOK CALLED ===\n");
307-
393+
308394
LOAD_ORIGINAL(MPI_Finalize);
309-
395+
310396
int ret = orig_MPI_Finalize();
311-
397+
312398
// Cleanup CXL memory
313399
if (g_cxl.initialized) {
314400
LOG_INFO("Cleaning up CXL memory (used %zu/%zu bytes)\n", g_cxl.used, g_cxl.size);
315401
munmap(g_cxl.base, g_cxl.size);
316402
close(g_cxl.fd);
317-
318-
if (strcmp(g_cxl.type, "shm") == 0) {
319-
char shm_name[256];
320-
snprintf(shm_name, sizeof(shm_name), "/cxlmemsim_mpi_%d", getuid());
403+
404+
// Don't unlink shared memory as other processes may still be using it
405+
// Only unlink if explicitly requested
406+
if (strcmp(g_cxl.type, "shm") == 0 && getenv("CXL_SHM_UNLINK")) {
407+
const char *shm_name = "/cxlmemsim_mpi_shared";
321408
shm_unlink(shm_name);
409+
LOG_INFO("Unlinked shared memory %s\n", shm_name);
322410
}
323-
411+
324412
g_cxl.initialized = false;
325413
}
326-
414+
327415
LOG_INFO("MPI_Finalize completed (total hooks: %d)\n", g_hook_count);
328-
416+
329417
return ret;
330418
}
331419

0 commit comments

Comments
 (0)