Skip to content

Commit 1400715

Browse files
committed
fix: Protect local memory stores with shmem_barrier_all()
1 parent b6cfc4e commit 1400715

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

src/libshmem/fallback.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ static inline long *_py_shmem_pSync()
311311
_py_shmem_pSync_array = (long *) shmem_malloc(SHMEM_SYNC_SIZE * sizeof(long));
312312
for (int i = 0; i < SHMEM_SYNC_SIZE; i++)
313313
_py_shmem_pSync_array[i] = SHMEM_SYNC_VALUE;
314-
shmem_sync_all();
314+
shmem_barrier_all();
315315
}
316316
return _py_shmem_pSync_array;
317317
}
@@ -328,7 +328,7 @@ static inline void *_py_shmem_pWrk(size_t nreduce, size_t eltsize)
328328
shmem_free(_py_shmem_pWrk_array);
329329
_py_shmem_pWrk_size = wrk_size;
330330
_py_shmem_pWrk_array = shmem_malloc(wrk_size);
331-
shmem_sync_all();
331+
shmem_barrier_all();
332332
}
333333
return _py_shmem_pWrk_array;
334334
}

src/libshmem/memalloc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ static void *shmem_py_alloc(size_t size, size_t align, long hints, int clear)
1313
}
1414
if (clear) {
1515
memset(ptr, 0, size);
16-
shmem_sync_all();
16+
shmem_barrier_all();
1717
}
1818
return ptr;
1919
}

src/shmem4py/shmem.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,7 @@ def array(
889889
if tmp.ndim > 1:
890890
a.strides = tmp.strides
891891
np.copyto(a, tmp, casting='no')
892-
lib.shmem_sync_all()
892+
lib.shmem_barrier_all()
893893
return a
894894

895895

@@ -969,7 +969,7 @@ def ones(
969969
"""
970970
a = new_array(shape, dtype, order, align=align, hints=hints, clear=False)
971971
np.copyto(a, 1, casting='unsafe')
972-
lib.shmem_sync_all()
972+
lib.shmem_barrier_all()
973973
return a
974974

975975

@@ -1001,7 +1001,7 @@ def full(
10011001
dtype = np.array(fill_value).dtype
10021002
a = new_array(shape, dtype, order, align=align, hints=hints, clear=False)
10031003
np.copyto(a, fill_value, casting='unsafe')
1004-
lib.shmem_sync_all()
1004+
lib.shmem_barrier_all()
10051005
return a
10061006

10071007

0 commit comments

Comments
 (0)