1515#include <unistd.h>
1616#include <execinfo.h>
1717#include <signal.h>
18+ #include <stdarg.h>
1819
// Shim-wide configuration constants.
#define CACHELINE_SIZE 64
#define DEFAULT_CXL_SIZE (4UL * 1024 * 1024 * 1024) // 4GB default
#define CXL_ALIGNMENT 4096
#define SHIM_VERSION "2.0"

// Function-like minimum of two values. The opening parenthesis must directly
// follow the macro name; with whitespace in between the preprocessor would
// define an object-like macro and every MIN(x, y) use would fail to compile.
// Each argument may be evaluated twice, so do not pass expressions with side
// effects.
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
28+
2429// Add color output for better visibility
2530#define RED "\x1b[31m"
2631#define GREEN "\x1b[32m"
@@ -109,90 +114,145 @@ static void signal_handler(int sig) {
109114 exit (1 );
110115}
111116
// Look up the size in bytes of a DAX device through sysfs.
//
// dax_path: path of the device node (for example "/dev/dax0.0"). Only the
//           final path component is used; it is substituted into
//           /sys/bus/dax/devices/<name>/size.
//
// Returns the value read from that attribute, or 0 when the size cannot be
// determined (NULL/empty device name, sysfs path too long for the local
// buffer, attribute not readable, contents not an unsigned integer, or value
// not representable in size_t). The caller treats 0 as "unknown" and falls
// back to other means.
static size_t get_dax_size(const char *dax_path) {
    if (!dax_path) {
        return 0;
    }

    // Strip any directory part: "/dev/dax0.0" -> "dax0.0".
    const char *dev_name = strrchr(dax_path, '/');
    dev_name = dev_name ? dev_name + 1 : dax_path;
    if (*dev_name == '\0') {
        LOG_WARN("Cannot derive DAX device name from %s\n", dax_path);
        return 0;
    }

    char sysfs_path[512];
    int written = snprintf(sysfs_path, sizeof(sysfs_path),
                           "/sys/bus/dax/devices/%s/size", dev_name);
    if (written < 0 || (size_t)written >= sizeof(sysfs_path)) {
        // A truncated path would name a different (or nonexistent) attribute.
        LOG_WARN("DAX sysfs path too long for device %s\n", dev_name);
        return 0;
    }

    FILE *f = fopen(sysfs_path, "r");
    if (!f) {
        LOG_WARN("Cannot read DAX size from %s, using stat\n", sysfs_path);
        return 0;
    }

    unsigned long long size = 0;
    int parsed = fscanf(f, "%llu", &size);
    fclose(f);
    if (parsed != 1) {
        LOG_WARN("Cannot parse DAX size from %s, using stat\n", sysfs_path);
        return 0;
    }

    // Reject values that would be silently truncated by the size_t cast.
    if ((unsigned long long)(size_t)size != size) {
        LOG_WARN("DAX size from %s does not fit in size_t\n", sysfs_path);
        return 0;
    }

    LOG_DEBUG("DAX device %s size from sysfs: %llu bytes\n", dev_name, size);
    return (size_t)size;
}
142+
112143// Initialize CXL memory
113144static void init_cxl_memory (void ) {
114145 if (g_cxl .initialized ) return ;
115-
146+
116147 pthread_mutex_lock (& g_cxl .lock );
117148 if (g_cxl .initialized ) {
118149 pthread_mutex_unlock (& g_cxl .lock );
119150 return ;
120151 }
121-
152+
122153 const char * dax_path = getenv ("CXL_DAX_PATH" );
123154 const char * cxl_size_str = getenv ("CXL_MEM_SIZE" );
124155 size_t cxl_size = cxl_size_str ? strtoull (cxl_size_str , NULL , 0 ) : DEFAULT_CXL_SIZE ;
125-
126- if (dax_path && access (dax_path , R_OK | W_OK ) == 0 ) {
127- // Use DAX device
156+
157+ if (dax_path && strlen (dax_path ) > 0 ) {
158+ // Use DAX device - open with O_RDWR for shared access
128159 g_cxl .fd = open (dax_path , O_RDWR );
129160 if (g_cxl .fd < 0 ) {
130161 LOG_ERROR ("Failed to open DAX device %s: %s\n" , dax_path , strerror (errno ));
131162 goto use_shm ;
132163 }
133-
134- // Get actual DAX size
135- struct stat st ;
136- if (fstat (g_cxl .fd , & st ) == 0 ) {
137- cxl_size = st .st_size ;
164+
165+ // Try to get DAX size from sysfs first
166+ cxl_size = get_dax_size (dax_path );
167+ if (cxl_size == 0 ) {
168+ // Fallback to stat
169+ struct stat st ;
170+ if (fstat (g_cxl .fd , & st ) == 0 ) {
171+ cxl_size = st .st_size ;
172+ }
173+ if (cxl_size == 0 ) {
174+ // Use default if still 0
175+ cxl_size = DEFAULT_CXL_SIZE ;
176+ }
138177 }
139-
178+
140179 g_cxl .base = mmap (NULL , cxl_size , PROT_READ | PROT_WRITE , MAP_SHARED , g_cxl .fd , 0 );
141180 if (g_cxl .base == MAP_FAILED ) {
142- LOG_ERROR ("Failed to map DAX device: %s\n" , strerror (errno ));
181+ LOG_ERROR ("Failed to map DAX device %s : %s\n" , dax_path , strerror (errno ));
143182 close (g_cxl .fd );
144183 goto use_shm ;
145184 }
146-
185+
147186 g_cxl .type = "dax" ;
148- LOG_INFO ("Mapped DAX device %s: %zu bytes at %p\n" , dax_path , cxl_size , g_cxl .base );
187+ LOG_INFO ("Mapped DAX device %s: %zu bytes (%zu MB) at %p\n" ,
188+ dax_path , cxl_size , cxl_size / (1024 * 1024 ), g_cxl .base );
189+
190+ // For DAX, we need to coordinate allocation between processes
191+ // Use first cacheline as allocation counter
192+ if (getenv ("CXL_DAX_RESET" )) {
193+ // Only reset if explicitly requested
194+ memset (g_cxl .base , 0 , CACHELINE_SIZE );
195+ LOG_INFO ("Reset DAX allocation counter\n" );
196+ }
197+
198+ // DAX allocation starts after first cacheline
199+ g_cxl .used = CACHELINE_SIZE ;
200+
149201 } else {
150202use_shm :
151- // Use shared memory fallback
152- char shm_name [256 ];
153- snprintf (shm_name , sizeof (shm_name ), "/cxlmemsim_mpi_%d" , getuid ());
154-
155- // Try to unlink first in case it exists
156- shm_unlink (shm_name );
157-
158- g_cxl .fd = shm_open (shm_name , O_CREAT | O_RDWR | O_EXCL , 0600 );
203+ // Use shared memory fallback - create a single shared segment
204+ const char * shm_name = "/cxlmemsim_mpi_shared" ;
205+
206+ // Try to open existing first
207+ g_cxl .fd = shm_open (shm_name , O_RDWR , 0600 );
208+
159209 if (g_cxl .fd < 0 ) {
160- LOG_ERROR ("Failed to create shared memory %s: %s\n" , shm_name , strerror (errno ));
161- pthread_mutex_unlock (& g_cxl .lock );
162- return ;
163- }
164-
165- if (ftruncate (g_cxl .fd , cxl_size ) != 0 ) {
166- LOG_ERROR ("Failed to resize shared memory: %s\n" , strerror (errno ));
167- close (g_cxl .fd );
168- shm_unlink (shm_name );
169- pthread_mutex_unlock (& g_cxl .lock );
170- return ;
210+ // Create new
211+ g_cxl .fd = shm_open (shm_name , O_CREAT | O_RDWR , 0666 );
212+ if (g_cxl .fd < 0 ) {
213+ LOG_ERROR ("Failed to create/open shared memory %s: %s\n" , shm_name , strerror (errno ));
214+ pthread_mutex_unlock (& g_cxl .lock );
215+ return ;
216+ }
217+
218+ if (ftruncate (g_cxl .fd , cxl_size ) != 0 ) {
219+ LOG_ERROR ("Failed to resize shared memory: %s\n" , strerror (errno ));
220+ close (g_cxl .fd );
221+ pthread_mutex_unlock (& g_cxl .lock );
222+ return ;
223+ }
224+ LOG_INFO ("Created new shared memory segment %s\n" , shm_name );
225+ } else {
226+ // Get existing size
227+ struct stat st ;
228+ if (fstat (g_cxl .fd , & st ) == 0 ) {
229+ cxl_size = st .st_size ;
230+ }
231+ LOG_INFO ("Opened existing shared memory segment %s\n" , shm_name );
171232 }
172-
233+
173234 g_cxl .base = mmap (NULL , cxl_size , PROT_READ | PROT_WRITE , MAP_SHARED , g_cxl .fd , 0 );
174235 if (g_cxl .base == MAP_FAILED ) {
175236 LOG_ERROR ("Failed to map shared memory: %s\n" , strerror (errno ));
176237 close (g_cxl .fd );
177- shm_unlink (shm_name );
178238 pthread_mutex_unlock (& g_cxl .lock );
179239 return ;
180240 }
181-
241+
182242 g_cxl .type = "shm" ;
183- LOG_INFO ("Created shared memory %s: %zu bytes at %p\n" , shm_name , cxl_size , g_cxl .base );
243+ LOG_INFO ("Mapped shared memory %s: %zu bytes (%zu MB) at %p\n" ,
244+ shm_name , cxl_size , cxl_size / (1024 * 1024 ), g_cxl .base );
245+
246+ // For shared memory, also use first cacheline for coordination
247+ g_cxl .used = CACHELINE_SIZE ;
184248 }
185-
249+
186250 g_cxl .size = cxl_size ;
187- g_cxl .used = 0 ;
188251 g_cxl .initialized = true;
189-
190- // Clear the memory
191- memset (g_cxl .base , 0 , MIN (4096 , cxl_size ));
192-
193- LOG_INFO ("CXL memory initialized: type=%s, size=%zu MB, base=%p\n" ,
252+
253+ LOG_INFO ("CXL memory initialized: type=%s, size=%zu MB, base=%p\n" ,
194254 g_cxl .type , cxl_size / (1024 * 1024 ), g_cxl .base );
195-
255+
196256 pthread_mutex_unlock (& g_cxl .lock );
197257}
198258
@@ -202,28 +262,54 @@ static void *allocate_cxl_memory(size_t size) {
202262 init_cxl_memory ();
203263 if (!g_cxl .initialized ) return NULL ;
204264 }
205-
265+
206266 // Align size
207267 size = (size + CXL_ALIGNMENT - 1 ) & ~(CXL_ALIGNMENT - 1 );
208-
209- pthread_mutex_lock (& g_cxl .lock );
210-
211- if (g_cxl .used + size > g_cxl .size ) {
212- LOG_WARN ("Out of CXL memory: requested=%zu, available=%zu\n" ,
213- size , g_cxl .size - g_cxl .used );
268+
269+ // For DAX/shared memory, use atomic operations on the allocation counter
270+ if (strcmp (g_cxl .type , "dax" ) == 0 || strcmp (g_cxl .type , "shm" ) == 0 ) {
271+ // Use first 8 bytes of the shared region as atomic allocation counter
272+ _Atomic size_t * alloc_counter = (_Atomic size_t * )g_cxl .base ;
273+
274+ size_t old_used = atomic_fetch_add (alloc_counter , size );
275+ size_t new_used = old_used + size ;
276+
277+ // Check if we have space (accounting for the counter itself)
278+ if (new_used > g_cxl .size ) {
279+ // Roll back
280+ atomic_fetch_sub (alloc_counter , size );
281+ LOG_WARN ("Out of CXL memory: requested=%zu, available=%zu\n" ,
282+ size , g_cxl .size - old_used );
283+ return NULL ;
284+ }
285+
286+ void * ptr = (char * )g_cxl .base + old_used ;
287+
288+ LOG_TRACE ("Allocated %zu bytes at offset %zu (total used: %zu/%zu) [atomic]\n" ,
289+ size , old_used , new_used , g_cxl .size );
290+
291+ return ptr ;
292+ } else {
293+ // Local allocation (shouldn't happen but keep for safety)
294+ pthread_mutex_lock (& g_cxl .lock );
295+
296+ if (g_cxl .used + size > g_cxl .size ) {
297+ LOG_WARN ("Out of CXL memory: requested=%zu, available=%zu\n" ,
298+ size , g_cxl .size - g_cxl .used );
299+ pthread_mutex_unlock (& g_cxl .lock );
300+ return NULL ;
301+ }
302+
303+ void * ptr = (char * )g_cxl .base + g_cxl .used ;
304+ g_cxl .used += size ;
305+
306+ LOG_TRACE ("Allocated %zu bytes at offset %zu (total used: %zu/%zu)\n" ,
307+ size , (size_t )((char * )ptr - (char * )g_cxl .base ), g_cxl .used , g_cxl .size );
308+
214309 pthread_mutex_unlock (& g_cxl .lock );
215- return NULL ;
310+
311+ return ptr ;
216312 }
217-
218- void * ptr = (char * )g_cxl .base + g_cxl .used ;
219- g_cxl .used += size ;
220-
221- LOG_TRACE ("Allocated %zu bytes at offset %zu (total used: %zu/%zu)\n" ,
222- size , (size_t )((char * )ptr - (char * )g_cxl .base ), g_cxl .used , g_cxl .size );
223-
224- pthread_mutex_unlock (& g_cxl .lock );
225-
226- return ptr ;
227313}
228314
229315// Mapping management
@@ -304,28 +390,30 @@ int MPI_Init(int *argc, char ***argv) {
304390
305391int MPI_Finalize (void ) {
306392 LOG_INFO ("=== MPI_Finalize HOOK CALLED ===\n" );
307-
393+
308394 LOAD_ORIGINAL (MPI_Finalize );
309-
395+
310396 int ret = orig_MPI_Finalize ();
311-
397+
312398 // Cleanup CXL memory
313399 if (g_cxl .initialized ) {
314400 LOG_INFO ("Cleaning up CXL memory (used %zu/%zu bytes)\n" , g_cxl .used , g_cxl .size );
315401 munmap (g_cxl .base , g_cxl .size );
316402 close (g_cxl .fd );
317-
318- if (strcmp (g_cxl .type , "shm" ) == 0 ) {
319- char shm_name [256 ];
320- snprintf (shm_name , sizeof (shm_name ), "/cxlmemsim_mpi_%d" , getuid ());
403+
404+ // Don't unlink shared memory as other processes may still be using it
405+ // Only unlink if explicitly requested
406+ if (strcmp (g_cxl .type , "shm" ) == 0 && getenv ("CXL_SHM_UNLINK" )) {
407+ const char * shm_name = "/cxlmemsim_mpi_shared" ;
321408 shm_unlink (shm_name );
409+ LOG_INFO ("Unlinked shared memory %s\n" , shm_name );
322410 }
323-
411+
324412 g_cxl .initialized = false;
325413 }
326-
414+
327415 LOG_INFO ("MPI_Finalize completed (total hooks: %d)\n" , g_hook_count );
328-
416+
329417 return ret ;
330418}
331419
0 commit comments