Skip to content

Commit 973f25d

Browse files
canyon289junpenglao
authored andcommitted
[WIP] Fix categorical random shape (#3060)
Fix categorical random shape
1 parent 8e0d7ac commit 973f25d

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

pymc3/distributions/discrete.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ def __init__(self, p, *args, **kwargs):
712712

713713
def random(self, point=None, size=None):
714714
p, k = draw_values([self.p, self.k], point=point, size=size)
715+
715716
return generate_samples(random_choice,
716717
p=p,
717718
broadcast_shape=p.shape[:-1] or (1,),

pymc3/distributions/dist_math.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ def random_choice(*args, **kwargs):
302302
k = p.shape[-1]
303303

304304
if p.ndim > 1:
305-
samples = np.row_stack([np.random.choice(k, p=p_) for p_ in p])
305+
# If a 2d vector of probabilities is passed return a sample for each row of categorical probability
306+
samples = np.array([np.random.choice(k, p=p_) for p_ in p])
306307
else:
307308
samples = np.random.choice(k, p=p, size=size)
308309
return samples

pymc3/tests/test_distributions_random.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,7 @@ def test_broadcast_shape(self, size):
194194
@pytest.mark.parametrize('shape', [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)], ids=str)
195195
def test_different_shapes_and_sample_sizes(self, shape):
196196
prefix = self.distribution.__name__
197-
expected = []
198-
actual = []
197+
199198
rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape))
200199
for size in (None, 1, 5, (4, 5)):
201200
if size is None:
@@ -402,6 +401,11 @@ class TestCategorical(BaseTestCases.BaseTestCase):
402401
def get_random_variable(self, shape, with_vector_params=False, **kwargs): # don't transform categories
403402
return super(TestCategorical, self).get_random_variable(shape, with_vector_params=False, **kwargs)
404403

404+
def test_probability_vector_shape(self):
405+
"""Check that if a 2d array of probabilities are passed to categorical correct shape is returned"""
406+
p = np.ones((10, 5))
407+
assert pm.Categorical.dist(p=p).random().shape == (10,)
408+
405409

406410
class TestScalarParameterSamples(SeededTest):
407411
def test_bounded(self):

0 commit comments

Comments
 (0)