@@ -191,14 +191,11 @@ function kernel_BTRS!(
191191 @inbounds n = count[I1]
192192 @inbounds p = prob[CartesianIndex(I1, I2)]
193193 end
194- # BTRS approximations work well for p <= 0.5
195- # invert p and set `invert` flag
196- (invert = p > 0.5f0 ) && (p = 1 - p)
197194 else
198195 n = 0
199196 p = 0f0
200197 end
201-
198+
202199 # SAMPLER
203200 # edge cases
204201 if p <= 0 || n <= 0
@@ -213,66 +210,120 @@ function kernel_BTRS!(
213210 rand(Float32) < p && (k += 1 )
214211 ctr += 1
215212 end
216- # Use inversion algorithm for n*p < 10
217- elseif n * p < 10f0
218- logp = CUDA. log(1f0 - p)
219- geom_sum = 0f0
220- k = 0
221- while true
222- geom = ceil(CUDA. log(rand(Float32)) / logp)
223- geom_sum += geom
224- geom_sum > n && break
225- k += 1
226- end
227- # BTRS algorithm
228- else
229- r = p/ (1f0 - p)
230- s = p* (1f0 - p)
231-
232- stddev = sqrt(n * s)
233- b = 1.15f0 + 2.53f0 * stddev
234- a = - 0.0873f0 + 0.0248f0 * b + 0.01f0 * p
235- c = n * p + 0.5f0
236- v_r = 0.92f0 - 4.2f0 / b
237-
238- alpha = (2.83f0 + 5.1f0 / b) * stddev;
239- m = floor((n + 1 ) * p)
240-
241- ks = 0f0
242-
243- while true
244- usample = rand(Float32) - 0.5f0
245- vsample = rand(Float32)
246-
247- us = 0.5f0 - abs(usample)
248- ks = floor((2 * a / us + b) * usample + c)
249-
250- if us >= 0.07f0 && vsample <= v_r
251- break
213+ elseif p <= 0.5f0
214+ # Use inversion algorithm for n*p < 10
215+ if n * p < 10f0
216+ logp = CUDA. log(1f0 - p)
217+ geom_sum = 0f0
218+ k = 0
219+ while true
220+ geom = ceil(CUDA. log(rand(Float32)) / logp)
221+ geom_sum += geom
222+ geom_sum > n && break
223+ k += 1
252224 end
225+ # BTRS algorithm
226+ else
227+ r = p/ (1f0 - p)
228+ s = p* (1f0 - p)
229+
230+ stddev = sqrt(n * s)
231+ b = 1.15f0 + 2.53f0 * stddev
232+ a = - 0.0873f0 + 0.0248f0 * b + 0.01f0 * p
233+ c = n * p + 0.5f0
234+ v_r = 0.92f0 - 4.2f0 / b
235+
236+ alpha = (2.83f0 + 5.1f0 / b) * stddev;
237+ m = floor((n + 1 ) * p)
253238
254- if ks < 0 || ks > n
255- continue
239+ ks = 0f0
240+
241+ while true
242+ usample = rand(Float32) - 0.5f0
243+ vsample = rand(Float32)
244+
245+ us = 0.5f0 - abs(usample)
246+ ks = floor((2 * a / us + b) * usample + c)
247+
248+ if us >= 0.07f0 && vsample <= v_r
249+ break
250+ end
251+
252+ if ks < 0 || ks > n
253+ continue
254+ end
255+
256+ v2 = CUDA. log(vsample * alpha / (a / (us * us) + b))
257+ ub = (m + 0.5f0 ) * CUDA. log((m + 1 ) / (r * (n - m + 1 ))) +
258+ (n + 1 ) * CUDA. log((n - m + 1 ) / (n - ks + 1 )) +
259+ (ks + 0.5f0 ) * CUDA. log(r * (n - ks + 1 ) / (ks + 1 )) +
260+ stirling_approx_tail(m) + stirling_approx_tail(n - m) - stirling_approx_tail(ks) - stirling_approx_tail(n - ks)
261+ if v2 <= ub
262+ break
263+ end
264+ end
265+ k = Int(ks)
266+ end
267+ elseif p > 0.5f0
268+ p = 1 - p
269+ # Use inversion algorithm for n*p < 10
270+ if n * p < 10f0
271+ logp = CUDA. log(1f0 - p)
272+ geom_sum = 0f0
273+ k = 0
274+ while true
275+ geom = ceil(CUDA. log(rand(Float32)) / logp)
276+ geom_sum += geom
277+ geom_sum > n && break
278+ k += 1
256279 end
280+ # BTRS algorithm
281+ else
282+ r = p/ (1f0 - p)
283+ s = p* (1f0 - p)
284+
285+ stddev = sqrt(n * s)
286+ b = 1.15f0 + 2.53f0 * stddev
287+ a = - 0.0873f0 + 0.0248f0 * b + 0.01f0 * p
288+ c = n * p + 0.5f0
289+ v_r = 0.92f0 - 4.2f0 / b
290+
291+ alpha = (2.83f0 + 5.1f0 / b) * stddev;
292+ m = floor((n + 1 ) * p)
257293
258- v2 = CUDA. log(vsample * alpha / (a / (us * us) + b))
259- ub = (m + 0.5f0 ) * CUDA. log((m + 1 ) / (r * (n - m + 1 ))) +
260- (n + 1 ) * CUDA. log((n - m + 1 ) / (n - ks + 1 )) +
261- (ks + 0.5f0 ) * CUDA. log(r * (n - ks + 1 ) / (ks + 1 )) +
262- stirling_approx_tail(m) + stirling_approx_tail(n - m) - stirling_approx_tail(ks) - stirling_approx_tail(n - ks)
263- if v2 <= ub
264- break
294+ ks = 0f0
295+
296+ while true
297+ usample = rand(Float32) - 0.5f0
298+ vsample = rand(Float32)
299+
300+ us = 0.5f0 - abs(usample)
301+ ks = floor((2 * a / us + b) * usample + c)
302+
303+ if us >= 0.07f0 && vsample <= v_r
304+ break
305+ end
306+
307+ if ks < 0 || ks > n
308+ continue
309+ end
310+
311+ v2 = CUDA. log(vsample * alpha / (a / (us * us) + b))
312+ ub = (m + 0.5f0 ) * CUDA. log((m + 1 ) / (r * (n - m + 1 ))) +
313+ (n + 1 ) * CUDA. log((n - m + 1 ) / (n - ks + 1 )) +
314+ (ks + 0.5f0 ) * CUDA. log(r * (n - ks + 1 ) / (ks + 1 )) +
315+ stirling_approx_tail(m) + stirling_approx_tail(n - m) - stirling_approx_tail(ks) - stirling_approx_tail(n - ks)
316+ if v2 <= ub
317+ break
318+ end
265319 end
320+ k = Int(ks)
266321 end
267- k = Int(ks)
322+ k = n - k
268323 end
269324
270325 if i <= length(A)
271- if invert
272- @inbounds A[i] = n - k
273- else
274- @inbounds A[i] = k
275- end
326+ @inbounds A[i] = k
276327 end
277328 offset += window
278329 end
0 commit comments