Skip to content

Commit 95accbd

Browse files
authored
Fixed str parametrization for Negative Binomial distribution (#4187)
* Fixed repr parametrization for Negative Binomial distribution * Added tests Wrapped p and n with tt.as_tensor_variable
1 parent 9373d5a commit 95accbd

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

pymc3/distributions/discrete.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,8 +640,11 @@ def __init__(self, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
640640
self.mode = intX(tt.floor(mu))
641641

642642
def get_mu_alpha(self, mu=None, alpha=None, p=None, n=None):
643+
self._param_type = ["mu", "alpha"]
643644
if alpha is None:
644645
if n is not None:
646+
self._param_type[1] = "n"
647+
self.n = tt.as_tensor_variable(intX(n))
645648
alpha = n
646649
else:
647650
raise ValueError("Incompatible parametrization. Must specify either alpha or n.")
@@ -650,6 +653,8 @@ def get_mu_alpha(self, mu=None, alpha=None, p=None, n=None):
650653

651654
if mu is None:
652655
if p is not None:
656+
self._param_type[0] = "p"
657+
self.p = tt.as_tensor_variable(floatX(p))
653658
mu = alpha * (1 - p) / p
654659
else:
655660
raise ValueError("Incompatible parametrization. Must specify either mu or p.")
@@ -720,6 +725,9 @@ def logp(self, value):
720725
# Return Poisson when alpha gets very large.
721726
return tt.switch(tt.gt(alpha, 1e10), Poisson.dist(self.mu).logp(value), negbinom)
722727

728+
def _distr_parameters_for_repr(self):
729+
return self._param_type
730+
723731

724732
class Geometric(Discrete):
725733
R"""

pymc3/tests/test_distributions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,6 +1880,17 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
18801880
assert actual_a.shape == (X.shape[0],)
18811881
pass
18821882

1883+
def test_issue_4186(self):
1884+
with pm.Model():
1885+
nb = pm.NegativeBinomial(
1886+
"nb", mu=pm.Normal("mu"), alpha=pm.Gamma("alpha", mu=6, sigma=1)
1887+
)
1888+
assert str(nb) == "nb ~ NegativeBinomial(mu=mu, alpha=alpha)"
1889+
1890+
with pm.Model():
1891+
nb = pm.NegativeBinomial("nb", p=pm.Uniform("p"), n=10)
1892+
assert str(nb) == "nb ~ NegativeBinomial(p=p, n=10)"
1893+
18831894

18841895
def test_serialize_density_dist():
18851896
def func(x):

0 commit comments

Comments
 (0)