Skip to content

Commit 3069ca9

Browse files
committed
Merge branch 'master' into release_3.4.2
2 parents 3451a4a + d7374f5 commit 3069ca9

File tree

4 files changed

+35
-25
lines changed

4 files changed

+35
-25
lines changed

pymc3/distributions/discrete.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66

77
from pymc3.util import get_variable_name
8-
from .dist_math import bound, factln, binomln, betaln, logpow
8+
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
99
from .distribution import Discrete, draw_values, generate_samples
1010
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
1111

@@ -710,19 +710,9 @@ def __init__(self, p, *args, **kwargs):
710710
self.p = (p.T / tt.sum(p, -1)).T
711711
self.mode = tt.argmax(p)
712712

713-
def _random(self, k, p, size=None):
714-
if len(p.shape) > 1:
715-
return np.asarray(
716-
[np.random.choice(k, p=pp, size=size)
717-
for pp in p]
718-
)
719-
else:
720-
return np.asarray(np.random.choice(k, p=p, size=size))
721-
722713
def random(self, point=None, size=None):
723714
p, k = draw_values([self.p, self.k], point=point, size=size)
724-
return generate_samples(self._random,
725-
k=k,
715+
return generate_samples(random_choice,
726716
p=p,
727717
broadcast_shape=p.shape[:-1] or (1,),
728718
dist_shape=self.shape,

pymc3/distributions/dist_math.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,30 @@ def impl(self, x):
280280

281281

282282
i0e = I0e(upgrade_to_float, name='i0e')
283+
284+
285+
def random_choice(*args, **kwargs):
286+
"""Return draws from a categorial probability functions
287+
288+
Args:
289+
p: array
290+
Probability of each class
291+
size: int
292+
Number of draws to return
293+
k: int
294+
Number of bins
295+
296+
Returns:
297+
random sample: array
298+
299+
"""
300+
p = kwargs.pop('p')
301+
size = kwargs.pop('size')
302+
k = p.shape[-1]
303+
304+
if p.ndim > 1:
305+
samples = np.row_stack([np.random.choice(k, p=p_) for p_ in p])
306+
else:
307+
samples = np.random.choice(k, p=p, size=size)
308+
return samples
309+

pymc3/distributions/distribution.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def random(self, *args, **kwargs):
214214
"Define a custom random method and pass it as kwarg random")
215215

216216

217-
218217
def draw_values(params, point=None, size=None):
219218
"""
220219
Draw (fix) parameter values. Handles a number of cases:

pymc3/distributions/mixture.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pymc3.util import get_variable_name
55
from ..math import logsumexp
6-
from .dist_math import bound
6+
from .dist_math import bound, random_choice
77
from .distribution import Discrete, Distribution, draw_values, generate_samples
88
from .continuous import get_tau_sd, Normal
99

@@ -147,24 +147,18 @@ def logp(self, value):
147147
broadcast_conditions=False)
148148

149149
def random(self, point=None, size=None):
150-
def random_choice(*args, **kwargs):
151-
w = kwargs.pop('w')
152-
w /= w.sum(axis=-1, keepdims=True)
153-
k = w.shape[-1]
154-
155-
if w.ndim > 1:
156-
return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
157-
else:
158-
return np.random.choice(k, p=w, *args, **kwargs)
159-
160150
w = draw_values([self.w], point=point)[0]
161151
comp_tmp = self._comp_samples(point=point, size=None)
162152
if np.asarray(self.shape).size == 0:
163153
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
164154
else:
165155
distshape = np.asarray(self.shape)
156+
157+
# Normalize inputs
158+
w /= w.sum(axis=-1, keepdims=True)
159+
166160
w_samples = generate_samples(random_choice,
167-
w=w,
161+
p=w,
168162
broadcast_shape=w.shape[:-1] or (1,),
169163
dist_shape=distshape,
170164
size=size).squeeze()

0 commit comments

Comments
 (0)