Skip to content

Commit 3722967

Browse files
committed
Exponential scale default to 1.0
1 parent 3f3aeb9 commit 3722967

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

pymc/distributions/continuous.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,13 +1373,12 @@ class Exponential(PositiveContinuous):
13731373
rv_op = exponential
13741374

13751375
@classmethod
1376-
def dist(cls, lam=None, scale=None, *args, **kwargs):
1377-
if lam is not None and scale is not None:
1376+
def dist(cls, lam=None, *, scale=None, **kwargs):
1377+
if lam is None and scale is None:
1378+
scale = 1.0
1379+
elif lam is not None and scale is not None:
13781380
raise ValueError("Incompatible parametrization. Can't specify both lam and scale.")
1379-
elif lam is None and scale is None:
1380-
raise ValueError("Incompatible parametrization. Must specify either lam or scale.")
1381-
1382-
if scale is None:
1381+
elif lam is not None:
13831382
scale = pt.reciprocal(lam)
13841383

13851384
scale = pt.as_tensor_variable(scale)

tests/distributions/test_continuous.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -461,15 +461,6 @@ def test_exponential(self):
461461
lambda q, lam: st.expon.ppf(q, loc=0, scale=1 / lam),
462462
)
463463

464-
def test_exponential_wrong_arguments(self):
465-
msg = "Incompatible parametrization. Can't specify both lam and scale"
466-
with pytest.raises(ValueError, match=msg):
467-
pm.Exponential.dist(lam=0.5, scale=5)
468-
469-
msg = "Incompatible parametrization. Must specify either lam or scale"
470-
with pytest.raises(ValueError, match=msg):
471-
pm.Exponential.dist()
472-
473464
def test_laplace(self):
474465
check_logp(
475466
pm.Laplace,
@@ -2274,8 +2265,20 @@ class TestExponential(BaseTestDistributionRandom):
22742265
checks_to_run = [
22752266
"check_pymc_params_match_rv_op",
22762267
"check_pymc_draws_match_reference",
2268+
"check_both_lam_scale_raises",
2269+
"check_default_scale",
22772270
]
22782271

2272+
def check_both_lam_scale_raises(self):
2273+
msg = "Incompatible parametrization. Can't specify both lam and scale"
2274+
with pytest.raises(ValueError, match=msg):
2275+
pm.Exponential.dist(lam=0.5, scale=5)
2276+
2277+
def check_default_scale(self):
2278+
rv = self.pymc_dist.dist()
2279+
[scale] = rv.owner.op.dist_params(rv.owner)
2280+
assert scale.data == 1.0
2281+
22792282

22802283
class TestExponentialScale(BaseTestDistributionRandom):
22812284
pymc_dist = pm.Exponential

0 commit comments

Comments
 (0)