Skip to content

Commit 139bc88

Browse files
committed
feedback 1/n
1 parent 2bc25c0 commit 139bc88

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

numpyro/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
_VALIDATION_ENABLED = True
5757

5858

59-
def enable_validation(is_validate: bool = False) -> None:
59+
def enable_validation(is_validate: bool = True) -> None:
6060
"""
6161
Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and
6262
errors, e.g. NaN checks, validating distribution arguments and support values, etc. which is

test/test_distributions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3693,13 +3693,13 @@ def test_vmap_validate_args():
36933693

36943694
def test_explicit_validate_args():
36953695
# Check validation passes for valid parameters.
3696-
d = dist.Normal(0, 1)
3697-
d.validate_args(False)
3696+
d = dist.Normal(0, 1, validate_args=False)
3697+
d.validate_args()
36983698

36993699
# Check validation fails for invalid parameters.
37003700
d = dist.Normal(0, -1)
37013701
with pytest.raises(ValueError, match="got invalid scale parameter"):
3702-
d.validate_args(False)
3702+
d.validate_args()
37033703

37043704
# Check validation is skipped for strict=False and raises an error for strict=True.
37053705
jitted = jax.jit(

test/test_pickle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def bernoulli_model():
6565

6666

6767
def logistic_regression():
68-
data = jnp.arange(10)
68+
data = random.choice(random.PRNGKey(0), jnp.array([0, 1]), (10,))
6969
x = numpyro.sample("x", dist.Normal(0, 1))
7070
with numpyro.plate("N", 10, subsample_size=2):
7171
batch = numpyro.subsample(data, 0)

test/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_format_shapes():
112112

113113
def model_test():
114114
mean = numpyro.param("mean", jnp.zeros(len(data)))
115-
scale = numpyro.sample("scale", dist.Normal(0, 1).expand([3]).to_event(1))
115+
scale = numpyro.sample("scale", dist.LogNormal(0, 1).expand([3]).to_event(1))
116116
scale = scale.sum()
117117
with numpyro.plate("data", len(data), subsample_size=10) as ind:
118118
batch = data[ind]

0 commit comments

Comments
 (0)