|
4 | 4 | """
|
5 | 5 |
|
6 | 6 | import numpy as np
|
7 |
| -from numba import jit, guvectorize, generated_jit, types |
| 7 | +from numba import guvectorize, generated_jit, types |
8 | 8 |
|
9 | 9 | from ..util import check_random_state, searchsorted
|
10 | 10 |
|
@@ -98,7 +98,6 @@ def _probvec(r, out):
|
98 | 98 | )(_probvec)
|
99 | 99 |
|
100 | 100 |
|
101 |
| -@jit |
102 | 101 | def sample_without_replacement(n, k, num_trials=None, random_state=None):
|
103 | 102 | """
|
104 | 103 | Randomly choose k integers without replacement from 0, ..., n-1.
|
@@ -144,26 +143,30 @@ def sample_without_replacement(n, k, num_trials=None, random_state=None):
|
144 | 143 | if k > n:
|
145 | 144 | raise ValueError('k must be smaller than or equal to n')
|
146 | 145 |
|
147 |
| - m = 1 if num_trials is None else num_trials |
| 146 | + size = k if num_trials is None else (num_trials, k) |
148 | 147 |
|
149 | 148 | random_state = check_random_state(random_state)
|
150 |
| - r = random_state.random_sample(size=(m, k)) |
| 149 | + r = random_state.random_sample(size=size) |
| 150 | + result = _sample_without_replacement(n, r) |
| 151 | + |
| 152 | + return result |
| 153 | + |
| 154 | + |
| 155 | +@guvectorize(['(i8, f8[:], i8[:])'], '(),(k)->(k)', nopython=True, cache=True) |
| 156 | +def _sample_without_replacement(n, r, out): |
| 157 | + """ |
| 158 | + Main body of `sample_without_replacement`. To be complied as a ufunc |
| 159 | + by guvectorize of Numba. |
| 160 | +
|
| 161 | + """ |
| 162 | + k = r.shape[0] |
151 | 163 |
|
152 | 164 | # Logic taken from random.sample in the standard library
|
153 |
| - result = np.empty((m, k), dtype=int) |
154 |
| - pool = np.empty(n, dtype=int) |
155 |
| - for i in range(m): |
156 |
| - for j in range(n): |
157 |
| - pool[j] = j |
158 |
| - for j in range(k): |
159 |
| - idx = int(np.floor(r[i, j] * (n-j))) # np.floor returns a float |
160 |
| - result[i, j] = pool[idx] |
161 |
| - pool[idx] = pool[n-j-1] |
162 |
| - |
163 |
| - if num_trials is None: |
164 |
| - return result[0] |
165 |
| - else: |
166 |
| - return result |
| 165 | + pool = np.arange(n) |
| 166 | + for j in range(k): |
| 167 | + idx = int(np.floor(r[j] * (n-j))) # np.floor returns a float |
| 168 | + out[j] = pool[idx] |
| 169 | + pool[idx] = pool[n-j-1] |
167 | 170 |
|
168 | 171 |
|
169 | 172 | @generated_jit(nopython=True)
|
|
0 commit comments