Skip to content

Commit 8279b0b

Browse files
Fix uniform distribution log_prob support (#2095)
* initial fix * initial test * add more tests * add validation by default * Update numpyro/distributions/distribution.py Co-authored-by: Meesum Qazalbash <meesumqazalbash@gmail.com> * improve some validation * remove bad noqa generating warnings * validate to false * fix support * rm dupplicated code * feedback 1/n * undo uniform changes and admend test * try to fix validation tests * fix 1/n XD --------- Co-authored-by: Meesum Qazalbash <meesumqazalbash@gmail.com>
1 parent 212bccd commit 8279b0b

File tree

7 files changed

+24
-11
lines changed

7 files changed

+24
-11
lines changed

numpyro/distributions/censored.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
base_dist: DistributionT,
8282
censored: ArrayLike = False,
8383
*,
84-
validate_args: Optional[bool] = None,
84+
validate_args: bool = False,
8585
):
8686
# test if base_dist has an implemented cdf method
8787
if not hasattr(base_dist, "cdf"):
@@ -197,7 +197,7 @@ def __init__(
197197
base_dist: DistributionT,
198198
censored: ArrayLike = False,
199199
*,
200-
validate_args: Optional[bool] = None,
200+
validate_args: bool = False,
201201
):
202202
# test if base_dist has an implemented cdf method
203203
if not hasattr(base_dist, "cdf"):
@@ -335,7 +335,7 @@ def __init__(
335335
left_censored: ArrayLike,
336336
right_censored: ArrayLike,
337337
*,
338-
validate_args: Optional[bool] = None,
338+
validate_args: bool = False,
339339
):
340340
# Optionally test that cdf actually works (in validate_args mode)
341341
if validate_args:

numpyro/distributions/continuous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3577,7 +3577,7 @@ class Levy(Distribution):
35773577
"""
35783578

35793579
arg_constraints = {
3580-
"loc": constraints.positive,
3580+
"loc": constraints.real,
35813581
"scale": constraints.positive,
35823582
}
35833583

numpyro/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454
from . import constraints
5555

56-
_VALIDATION_ENABLED = False
56+
_VALIDATION_ENABLED = True
5757

5858

5959
def enable_validation(is_validate: bool = True) -> None:
@@ -1320,7 +1320,7 @@ class Unit(Distribution):
13201320
arg_constraints = {"log_factor": constraints.real}
13211321
support = constraints.real
13221322

1323-
def __init__(self, log_factor: ArrayLike, *, validate_args: Optional[bool] = None):
1323+
def __init__(self, log_factor: ArrayLike, *, validate_args: bool = False):
13241324
batch_shape = jnp.shape(log_factor)
13251325
event_shape = (0,) # This satisfies .size == 0.
13261326
self.log_factor = log_factor

test/contrib/test_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def __call__(self, x, state):
620620
return x, state
621621

622622
# Eager initialization of the Net module outside the model
623-
net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0)) # noqa: E1111
623+
net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0))
624624
x = dist.Normal(0, 1).expand([4, 3]).to_event(2).sample(random.PRNGKey(0))
625625

626626
def model():

test/test_distributions.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3693,11 +3693,11 @@ def test_vmap_validate_args():
36933693

36943694
def test_explicit_validate_args():
36953695
# Check validation passes for valid parameters.
3696-
d = dist.Normal(0, 1)
3696+
d = dist.Normal(0, 1, validate_args=False)
36973697
d.validate_args()
36983698

36993699
# Check validation fails for invalid parameters.
3700-
d = dist.Normal(0, -1)
3700+
d = dist.Normal(0, -1, validate_args=False)
37013701
with pytest.raises(ValueError, match="got invalid scale parameter"):
37023702
d.validate_args()
37033703

@@ -4766,3 +4766,16 @@ def log_prob_fn(params):
47664766
f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
47674767
f"should be finite"
47684768
)
4769+
4770+
4771+
def test_uniform_log_prob_outside_support():
4772+
from numpyro.distributions.distribution import enable_validation
4773+
4774+
enable_validation()
4775+
4776+
d = dist.Uniform(0, 1)
4777+
with pytest.warns(
4778+
UserWarning,
4779+
match="Out-of-support values provided to log prob method. The value argument should be within the support.",
4780+
):
4781+
d.log_prob(-0.5)

test/test_pickle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def bernoulli_model():
6565

6666

6767
def logistic_regression():
68-
data = jnp.arange(10)
68+
data = random.choice(random.PRNGKey(0), jnp.array([0, 1]), (10,))
6969
x = numpyro.sample("x", dist.Normal(0, 1))
7070
with numpyro.plate("N", 10, subsample_size=2):
7171
batch = numpyro.subsample(data, 0)

test/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_format_shapes():
112112

113113
def model_test():
114114
mean = numpyro.param("mean", jnp.zeros(len(data)))
115-
scale = numpyro.sample("scale", dist.Normal(0, 1).expand([3]).to_event(1))
115+
scale = numpyro.sample("scale", dist.LogNormal(0, 1).expand([3]).to_event(1))
116116
scale = scale.sum()
117117
with numpyro.plate("data", len(data), subsample_size=10) as ind:
118118
batch = data[ind]

0 commit comments

Comments
 (0)