Skip to content

Commit cd25fb6

Browse files
committed
use correct wavefrontsize for device-side rng
Using the wrong wavefrontsize here technically could cause non-deterministic results even with a fixed seed. This is because mutiple threads inside the same wavefront are writing different values to the same memory location non-atomically. In CUDA, which value is actually stored is undefined and I am assuming ROCm makes similar assumptions, though I didn't find it documented anywhere. Testing with RDNA3.5 in wavefront64 mode though, I wasn't able to produce non-deterministic results from this, so it might not be an issue in practice. Still, I believe it's good practice to not rely on this being deterministic. One issue that remains is that we are allocating 32 counters and keys, even though 16 would be sufficient for wavefront64 targets. I didn't find a good way to emit these globals in a target-specific way, any suggestions welcome! closes #819
1 parent 4069021 commit cd25fb6

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/device/random.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,10 @@ end
8282
threadId = workitemIdx().x +
8383
(workitemIdx().y - Int32(1)) * workgroupDim().x +
8484
(workitemIdx().z - Int32(1)) * workgroupDim().x * workgroupDim().y
85-
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
85+
wavefrontsize_log2 = ifelse(wavefrontsize() == UInt32(32), 0x5, 0x6)
86+
warpId = (threadId - Int32(1)) >> wavefrontsize_log2 + Int32(1) # fld1 by 32
8687

87-
if field === :seed
88-
@inbounds global_random_seed()[1]
89-
elseif field === :key
88+
if field === :key
9089
@inbounds global_random_keys()[warpId]
9190
elseif field === :ctr1
9291
@inbounds global_random_counters()[warpId]
@@ -104,7 +103,8 @@ end
104103
threadId = workitemIdx().x +
105104
(workitemIdx().y - Int32(1)) * workgroupDim().x +
106105
(workitemIdx().z - Int32(1)) * workgroupDim().x * workgroupDim().y
107-
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
106+
wavefrontsize_log2 = ifelse(wavefrontsize() == UInt32(32), 0x5, 0x6)
107+
warpId = (threadId - Int32(1)) >> wavefrontsize_log2 + Int32(1) # fld1 by 32
108108

109109
if field === :key
110110
@inbounds global_random_keys()[warpId] = x

0 commit comments

Comments
 (0)