@@ -2742,13 +2742,14 @@ def test_generated_sample_distribution(
27422742@pytest .mark .parametrize (
27432743 "jax_dist, params, support" ,
27442744 [
2745- (dist .BernoulliLogits , (5.0 ,), jnp .arange (2 )),
2746- (dist .BernoulliProbs , (0.5 ,), jnp .arange (2 )),
2747- (dist .BinomialLogits , (4.5 , 10 ), jnp .arange (11 )),
2748- (dist .BinomialProbs , (0.5 , 11 ), jnp .arange (12 )),
2749- (dist .BetaBinomial , (2.0 , 0.5 , 12 ), jnp .arange (13 )),
2750- (dist .CategoricalLogits , (np .array ([3.0 , 4.0 , 5.0 ]),), jnp .arange (3 )),
2751- (dist .CategoricalProbs , (np .array ([0.1 , 0.5 , 0.4 ]),), jnp .arange (3 )),
2745+ (dist .BernoulliLogits , (5.0 ,), np .arange (2 )),
2746+ (dist .BernoulliProbs , (0.5 ,), np .arange (2 )),
2747+ (dist .BinomialLogits , (4.5 , 10 ), np .arange (11 )),
2748+ (dist .BinomialProbs , (0.5 , 11 ), np .arange (12 )),
2749+ (dist .BetaBinomial , (2.0 , 0.5 , 12 ), np .arange (13 )),
2750+ (dist .CategoricalLogits , (np .array ([3.0 , 4.0 , 5.0 ]),), np .arange (3 )),
2751+ (dist .CategoricalProbs , (np .array ([0.1 , 0.5 , 0.4 ]),), np .arange (3 )),
2752+ (dist .DiscreteUniform , (2 , 4 ), np .arange (2 , 5 )),
27522753 ],
27532754)
27542755@pytest .mark .parametrize ("batch_shape" , [(5 ,), ()])
@@ -3333,8 +3334,8 @@ def test_normal_log_cdf():
33333334 "value" ,
33343335 [
33353336 - 15.0 ,
3336- jnp .array ([[- 15.0 ], [- 10.0 ], [- 5.0 ]]),
3337- jnp .array ([[[- 15.0 ], [- 10.0 ], [- 5.0 ]], [[- 14.0 ], [- 9.0 ], [- 4.0 ]]]),
3337+ np .array ([[- 15.0 ], [- 10.0 ], [- 5.0 ]]),
3338+ np .array ([[[- 15.0 ], [- 10.0 ], [- 5.0 ]], [[- 14.0 ], [- 9.0 ], [- 4.0 ]]]),
33383339 ],
33393340)
33403341def test_truncated_normal_log_prob_in_tail (value ):
0 commit comments