Skip to content

Commit 2bc25c0

Browse files
committed
rm dupplicated code
1 parent 9ca88f1 commit 2bc25c0

File tree

1 file changed

+0
-103
lines changed

1 file changed

+0
-103
lines changed

test/test_distributions.py

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -4640,109 +4640,6 @@ def test_interval_censored_validate_sample(
46404640
censored_dist.log_prob(value) # Should not raise
46414641

46424642

4643-
def test_uniform_log_prob_outside_support():
4644-
d = dist.Uniform(0, 1)
4645-
assert_allclose(d.log_prob(-0.5), -jnp.inf)
4646-
assert_allclose(d.log_prob(1.5), -jnp.inf)
4647-
4648-
4649-
@pytest.mark.parametrize(
4650-
"low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)]
4651-
)
4652-
def test_uniform_log_prob_boundaries(low, high):
4653-
"""Test that boundary values are handled correctly."""
4654-
d = dist.Uniform(low, high)
4655-
expected_log_prob = -jnp.log(high - low)
4656-
4657-
# Value at lower bound (included): should have finite log prob
4658-
assert_allclose(d.log_prob(low), expected_log_prob)
4659-
4660-
# Value just above lower bound: should have finite log prob
4661-
assert_allclose(d.log_prob(low + 1e-10), expected_log_prob)
4662-
4663-
# Value at upper bound (excluded): should be -inf
4664-
assert_allclose(d.log_prob(high), -jnp.inf)
4665-
4666-
# Value just below upper bound: should have finite log prob
4667-
assert_allclose(d.log_prob(high - 1e-10), expected_log_prob)
4668-
4669-
# Value inside support: should have finite log prob
4670-
mid = (low + high) / 2.0
4671-
assert_allclose(d.log_prob(mid), expected_log_prob)
4672-
4673-
# Value below lower bound: should be -inf
4674-
assert_allclose(d.log_prob(low - 1.0), -jnp.inf)
4675-
4676-
# Value above upper bound: should be -inf
4677-
assert_allclose(d.log_prob(high + 1.0), -jnp.inf)
4678-
4679-
4680-
@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)])
4681-
def test_uniform_log_prob_broadcasting(batch_shape):
4682-
"""Test broadcasting with different batch shapes."""
4683-
if batch_shape == ():
4684-
low = 0.0
4685-
high = 1.0
4686-
else:
4687-
low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape)
4688-
high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape)
4689-
4690-
d = dist.Uniform(low, high)
4691-
4692-
# Test with scalar value
4693-
value = 0.5
4694-
log_probs = d.log_prob(value)
4695-
assert log_probs.shape == batch_shape
4696-
4697-
# Test with batched value
4698-
if batch_shape:
4699-
value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape(
4700-
batch_shape
4701-
)
4702-
log_probs_batched = d.log_prob(value_batched)
4703-
assert log_probs_batched.shape == batch_shape
4704-
4705-
# Check that values outside support return -inf
4706-
# Values < low should be -inf
4707-
below_low = low - 0.1
4708-
assert_allclose(d.log_prob(below_low), -jnp.inf)
4709-
4710-
# Values >= high should be -inf
4711-
at_high = high
4712-
assert_allclose(d.log_prob(at_high), -jnp.inf)
4713-
4714-
4715-
@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)])
4716-
def test_uniform_log_prob_value_broadcasting(value_shape):
4717-
"""Test broadcasting when value has different shapes."""
4718-
d = dist.Uniform(0.0, 1.0)
4719-
4720-
if value_shape == ():
4721-
values = 0.5
4722-
else:
4723-
values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape)
4724-
4725-
log_probs = d.log_prob(values)
4726-
assert log_probs.shape == value_shape
4727-
4728-
# Check that values inside support have finite log prob
4729-
inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1)
4730-
if value_shape:
4731-
inside_values = inside_values.reshape(value_shape)
4732-
log_probs_inside = d.log_prob(inside_values)
4733-
assert jnp.all(jnp.isfinite(log_probs_inside))
4734-
4735-
# Check that values outside support have -inf
4736-
outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1)
4737-
if value_shape:
4738-
outside_values = outside_values.reshape(value_shape)
4739-
log_probs_outside = d.log_prob(outside_values)
4740-
# Values in [0, 1) should be finite, others should be -inf
4741-
mask_inside = (outside_values >= 0.0) & (outside_values < 1.0)
4742-
assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True))
4743-
assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True))
4744-
4745-
47464643
@pytest.mark.parametrize(
47474644
argnames="concentration1,concentration0,value",
47484645
argvalues=[

0 commit comments

Comments
 (0)