Skip to content

Commit 48b0678

Browse files
authored
Merge pull request #479 from QuantEcon/fix_sample_without_replacement
Fix `sample_without_replacement` using guvectorize
2 parents 2402749 + d3c8b3c commit 48b0678

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

quantecon/random/utilities.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import numpy as np
7-
from numba import jit, guvectorize, generated_jit, types
7+
from numba import guvectorize, generated_jit, types
88

99
from ..util import check_random_state, searchsorted
1010

@@ -98,7 +98,6 @@ def _probvec(r, out):
9898
)(_probvec)
9999

100100

101-
@jit
102101
def sample_without_replacement(n, k, num_trials=None, random_state=None):
103102
"""
104103
Randomly choose k integers without replacement from 0, ..., n-1.
@@ -144,26 +143,30 @@ def sample_without_replacement(n, k, num_trials=None, random_state=None):
144143
if k > n:
145144
raise ValueError('k must be smaller than or equal to n')
146145

147-
m = 1 if num_trials is None else num_trials
146+
size = k if num_trials is None else (num_trials, k)
148147

149148
random_state = check_random_state(random_state)
150-
r = random_state.random_sample(size=(m, k))
149+
r = random_state.random_sample(size=size)
150+
result = _sample_without_replacement(n, r)
151+
152+
return result
153+
154+
155+
@guvectorize(['(i8, f8[:], i8[:])'], '(),(k)->(k)', nopython=True, cache=True)
156+
def _sample_without_replacement(n, r, out):
157+
"""
158+
Main body of `sample_without_replacement`. To be complied as a ufunc
159+
by guvectorize of Numba.
160+
161+
"""
162+
k = r.shape[0]
151163

152164
# Logic taken from random.sample in the standard library
153-
result = np.empty((m, k), dtype=int)
154-
pool = np.empty(n, dtype=int)
155-
for i in range(m):
156-
for j in range(n):
157-
pool[j] = j
158-
for j in range(k):
159-
idx = int(np.floor(r[i, j] * (n-j))) # np.floor returns a float
160-
result[i, j] = pool[idx]
161-
pool[idx] = pool[n-j-1]
162-
163-
if num_trials is None:
164-
return result[0]
165-
else:
166-
return result
165+
pool = np.arange(n)
166+
for j in range(k):
167+
idx = int(np.floor(r[j] * (n-j))) # np.floor returns a float
168+
out[j] = pool[idx]
169+
pool[idx] = pool[n-j-1]
167170

168171

169172
@generated_jit(nopython=True)

0 commit comments

Comments
 (0)