|
41 | 41 | Nat, |
42 | 42 | NatSmall, |
43 | 43 | R, |
44 | | - Rdunif, |
45 | 44 | Rplus, |
46 | | - Rplusdunif, |
47 | 45 | Runif, |
48 | 46 | Simplex, |
49 | 47 | Unit, |
@@ -95,32 +93,36 @@ def orderedprobit_logpdf(value, eta, cutpoints): |
95 | 93 |
|
96 | 94 |
|
97 | 95 | class TestMatchesScipy: |
98 | | - def test_discrete_unif(self): |
| 96 | + def test_discrete_uniform(self): |
| 97 | + # Choose domain/paramdomain so we test edge cases as well |
| 98 | + test_domain = Domain([-np.inf, -10, -1, 0, 1, 10, np.inf], dtype="int64") |
| 99 | + test_paramdomain = Domain([-np.inf, 0, 10, np.inf], dtype="int64") |
99 | 100 | check_logp( |
100 | 101 | pm.DiscreteUniform, |
101 | | - Rdunif, |
102 | | - {"lower": -Rplusdunif, "upper": Rplusdunif}, |
| 102 | + test_domain, |
| 103 | + {"lower": -test_paramdomain, "upper": test_paramdomain}, |
103 | 104 | lambda value, lower, upper: st.randint.logpmf(value, lower, upper + 1), |
104 | 105 | skip_paramdomain_outside_edge_test=True, |
105 | 106 | ) |
106 | 107 | check_logcdf( |
107 | 108 | pm.DiscreteUniform, |
108 | | - Rdunif, |
109 | | - {"lower": -Rplusdunif, "upper": Rplusdunif}, |
| 109 | + test_domain, |
| 110 | + {"lower": -test_paramdomain, "upper": test_paramdomain}, |
110 | 111 | lambda value, lower, upper: st.randint.logcdf(value, lower, upper + 1), |
111 | 112 | skip_paramdomain_outside_edge_test=True, |
112 | 113 | ) |
113 | 114 | check_selfconsistency_discrete_logcdf( |
114 | 115 | pm.DiscreteUniform, |
115 | 116 | Domain([-10, 0, 10], "int64"), |
116 | | - {"lower": -Rplusdunif, "upper": Rplusdunif}, |
| 117 | + {"lower": -test_paramdomain, "upper": test_paramdomain}, |
117 | 118 | ) |
118 | 119 | check_icdf( |
119 | 120 | pm.DiscreteUniform, |
120 | | - {"lower": -Rplusdunif, "upper": Rplusdunif}, |
| 121 | + {"lower": -test_paramdomain, "upper": test_paramdomain}, |
121 | 122 | lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1), |
122 | 123 | skip_paramdomain_outside_edge_test=True, |
123 | 124 | ) |
| 125 | + |
124 | 126 | # Custom logp / logcdf check for invalid parameters |
125 | 127 | invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0) |
126 | 128 | with pytensor.config.change_flags(mode=Mode("py")): |
|
0 commit comments