@@ -4640,109 +4640,6 @@ def test_interval_censored_validate_sample(
46404640 censored_dist .log_prob (value ) # Should not raise
46414641
46424642
4643- def test_uniform_log_prob_outside_support ():
4644- d = dist .Uniform (0 , 1 )
4645- assert_allclose (d .log_prob (- 0.5 ), - jnp .inf )
4646- assert_allclose (d .log_prob (1.5 ), - jnp .inf )
4647-
4648-
4649- @pytest .mark .parametrize (
4650- "low, high" , [(0.0 , 1.0 ), (- 2.0 , 3.0 ), (1.0 , 5.0 ), (- 5.0 , - 1.0 )]
4651- )
4652- def test_uniform_log_prob_boundaries (low , high ):
4653- """Test that boundary values are handled correctly."""
4654- d = dist .Uniform (low , high )
4655- expected_log_prob = - jnp .log (high - low )
4656-
4657- # Value at lower bound (included): should have finite log prob
4658- assert_allclose (d .log_prob (low ), expected_log_prob )
4659-
4660- # Value just above lower bound: should have finite log prob
4661- assert_allclose (d .log_prob (low + 1e-10 ), expected_log_prob )
4662-
4663- # Value at upper bound (excluded): should be -inf
4664- assert_allclose (d .log_prob (high ), - jnp .inf )
4665-
4666- # Value just below upper bound: should have finite log prob
4667- assert_allclose (d .log_prob (high - 1e-10 ), expected_log_prob )
4668-
4669- # Value inside support: should have finite log prob
4670- mid = (low + high ) / 2.0
4671- assert_allclose (d .log_prob (mid ), expected_log_prob )
4672-
4673- # Value below lower bound: should be -inf
4674- assert_allclose (d .log_prob (low - 1.0 ), - jnp .inf )
4675-
4676- # Value above upper bound: should be -inf
4677- assert_allclose (d .log_prob (high + 1.0 ), - jnp .inf )
4678-
4679-
4680- @pytest .mark .parametrize ("batch_shape" , [(), (3 ,), (2 , 3 ), (4 , 2 , 3 )])
4681- def test_uniform_log_prob_broadcasting (batch_shape ):
4682- """Test broadcasting with different batch shapes."""
4683- if batch_shape == ():
4684- low = 0.0
4685- high = 1.0
4686- else :
4687- low = jnp .linspace (0.0 , 1.0 , np .prod (batch_shape )).reshape (batch_shape )
4688- high = jnp .linspace (1.0 , 2.0 , np .prod (batch_shape )).reshape (batch_shape )
4689-
4690- d = dist .Uniform (low , high )
4691-
4692- # Test with scalar value
4693- value = 0.5
4694- log_probs = d .log_prob (value )
4695- assert log_probs .shape == batch_shape
4696-
4697- # Test with batched value
4698- if batch_shape :
4699- value_batched = jnp .linspace (- 0.5 , 1.5 , np .prod (batch_shape )).reshape (
4700- batch_shape
4701- )
4702- log_probs_batched = d .log_prob (value_batched )
4703- assert log_probs_batched .shape == batch_shape
4704-
4705- # Check that values outside support return -inf
4706- # Values < low should be -inf
4707- below_low = low - 0.1
4708- assert_allclose (d .log_prob (below_low ), - jnp .inf )
4709-
4710- # Values >= high should be -inf
4711- at_high = high
4712- assert_allclose (d .log_prob (at_high ), - jnp .inf )
4713-
4714-
4715- @pytest .mark .parametrize ("value_shape" , [(), (5 ,), (3 , 4 ), (2 , 3 , 4 )])
4716- def test_uniform_log_prob_value_broadcasting (value_shape ):
4717- """Test broadcasting when value has different shapes."""
4718- d = dist .Uniform (0.0 , 1.0 )
4719-
4720- if value_shape == ():
4721- values = 0.5
4722- else :
4723- values = jnp .linspace (- 0.5 , 1.5 , np .prod (value_shape )).reshape (value_shape )
4724-
4725- log_probs = d .log_prob (values )
4726- assert log_probs .shape == value_shape
4727-
4728- # Check that values inside support have finite log prob
4729- inside_values = jnp .linspace (0.1 , 0.9 , np .prod (value_shape ) if value_shape else 1 )
4730- if value_shape :
4731- inside_values = inside_values .reshape (value_shape )
4732- log_probs_inside = d .log_prob (inside_values )
4733- assert jnp .all (jnp .isfinite (log_probs_inside ))
4734-
4735- # Check that values outside support have -inf
4736- outside_values = jnp .linspace (- 1.0 , 2.0 , np .prod (value_shape ) if value_shape else 1 )
4737- if value_shape :
4738- outside_values = outside_values .reshape (value_shape )
4739- log_probs_outside = d .log_prob (outside_values )
4740- # Values in [0, 1) should be finite, others should be -inf
4741- mask_inside = (outside_values >= 0.0 ) & (outside_values < 1.0 )
4742- assert jnp .all (jnp .where (mask_inside , jnp .isfinite (log_probs_outside ), True ))
4743- assert jnp .all (jnp .where (~ mask_inside , log_probs_outside == - jnp .inf , True ))
4744-
4745-
47464643@pytest .mark .parametrize (
47474644 argnames = "concentration1,concentration0,value" ,
47484645 argvalues = [
0 commit comments