Skip to content

Commit 762619c

Browse files
irk_geometric_vec now handles p==1.0 without MKL error
Closes issue #15. Note that MKL samples from Geometric(p) are supported on the set of non-negative 32-bit integers, while samples from numpy.random are spported on 64-bit positive integers. ``` In [1]: import mkl_random as rnd In [2]: rnd.geometric(1, 10) Out[2]: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32) In [3]: import numpy as np In [4]: rnd2 = np.random.default_rng() In [5]: rnd2 Out[5]: Generator(PCG64) at 0x2AF349B17A50 In [6]: rnd2.geometric(1, 10) Out[6]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) In [7]: rnd2.geometric(1, 10)-1 Out[7]: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) In [8]: np.bincount(rnd2.geometric(0.78, 10**5)-1) Out[8]: array([77870, 17233, 3786, 879, 177, 44, 7, 3, 0, 0, 1]) In [9]: np.bincount(rnd.geometric(0.78, 10**5)) Out[9]: array([78121, 17088, 3743, 821, 172, 42, 10, 2, 1]) In [10]: np.bincount(rnd.geometric(0.78, 10**5)) Out[10]: array([78029, 17225, 3658, 848, 179, 48, 10, 3]) ```
1 parent eb0f7ec commit 762619c

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

mkl_random/mklrand.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4894,8 +4894,8 @@ cdef class RandomState:
48944894

48954895
fp = PyFloat_AsDouble(p)
48964896
if not PyErr_Occurred():
4897-
if fp < 0.0:
4898-
raise ValueError("p < 0.0")
4897+
if fp <= 0.0:
4898+
raise ValueError("p <= 0.0")
48994899
if fp > 1.0:
49004900
raise ValueError("p > 1.0")
49014901
return vec_discd_array_sc(self.internal_state, irk_geometric_vec, size, fp,
@@ -4905,7 +4905,7 @@ cdef class RandomState:
49054905

49064906

49074907
op = <ndarray>PyArray_FROM_OTF(p, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY)
4908-
if np.any(np.less(op, 0.0)):
4908+
if np.any(np.less_equal(op, 0.0)):
49094909
raise ValueError("p < 0.0")
49104910
if np.any(np.greater(op, 1.0)):
49114911
raise ValueError("p > 1.0")

mkl_random/src/mkl_distributions.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,17 +1169,26 @@ irk_geometric_vec(irk_state *state, npy_intp len, int *res, const double p)
11691169
if(len < 1)
11701170
return;
11711171

1172-
while(len > MKL_INT_MAX) {
1173-
err = viRngGeometric(VSL_RNG_METHOD_GEOMETRIC_ICDF, state->stream, MKL_INT_MAX, res, p);
1174-
assert(err == VSL_STATUS_OK);
1175-
1176-
res += MKL_INT_MAX;
1177-
len -= MKL_INT_MAX;
1178-
}
1172+
if ((0.0 < p) && (p < 1.0)) {
1173+
while(len > MKL_INT_MAX) {
1174+
err = viRngGeometric(VSL_RNG_METHOD_GEOMETRIC_ICDF, state->stream, MKL_INT_MAX, res, p);
1175+
assert(err == VSL_STATUS_OK);
11791176

1180-
err = viRngGeometric(VSL_RNG_METHOD_GEOMETRIC_ICDF, state->stream, len, res, p);
1181-
assert(err == VSL_STATUS_OK);
1177+
res += MKL_INT_MAX;
1178+
len -= MKL_INT_MAX;
1179+
}
11821180

1181+
err = viRngGeometric(VSL_RNG_METHOD_GEOMETRIC_ICDF, state->stream, len, res, p);
1182+
assert(err == VSL_STATUS_OK);
1183+
} else {
1184+
if (p==1.0) {
1185+
npy_intp i;
1186+
for(i=0; i < len; ++i) res[i] = 0;
1187+
} else {
1188+
assert(p >= 0.0);
1189+
assert(p <= 1.0);
1190+
}
1191+
}
11831192
}
11841193

11851194
void

0 commit comments

Comments
 (0)