Skip to content

Commit 74dccd8

Browse files
ricardoV94twiecki
authored andcommitted
Skip GenExtreme logcdf test on float32 and Windows
* Also refactor lambda for legibility
1 parent 53e153e commit 74dccd8

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

pymc_experimental/tests/distributions/test_continuous.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import platform
1415

1516
import numpy as np
1617
import pymc as pm
@@ -50,6 +51,12 @@ class TestGenExtremeClass:
5051
reason="PyMC underflows earlier than scipy on float32",
5152
)
5253
def test_logp(self):
54+
def ref_logp(value, mu, sigma, xi):
55+
if 1 + xi * (value - mu) / sigma > 0:
56+
return sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
57+
else:
58+
return -np.inf
59+
5360
check_logp(
5461
GenExtreme,
5562
R,
@@ -58,15 +65,23 @@ def test_logp(self):
5865
"sigma": Rplusbig,
5966
"xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
6067
},
61-
lambda value, mu, sigma, xi: sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
62-
if 1 + xi * (value - mu) / sigma > 0
63-
else -np.inf,
68+
ref_logp,
6469
)
6570

6671
if pytensor.config.floatX == "float32":
6772
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
6873

74+
@pytest.mark.skipif(
75+
(pytensor.config.floatX == "float32" and platform.system() == "Windows"),
76+
reason="Scipy gives different results on Windows and does not match with desired accuracy",
77+
)
6978
def test_logcdf(self):
79+
def ref_logcdf(value, mu, sigma, xi):
80+
if 1 + xi * (value - mu) / sigma > 0:
81+
return sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
82+
else:
83+
return -np.inf
84+
7085
check_logcdf(
7186
GenExtreme,
7287
R,
@@ -75,9 +90,7 @@ def test_logcdf(self):
7590
"sigma": Rplusbig,
7691
"xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
7792
},
78-
lambda value, mu, sigma, xi: sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
79-
if 1 + xi * (value - mu) / sigma > 0
80-
else -np.inf,
93+
ref_logcdf,
8194
decimal=select_by_precision(float64=6, float32=2),
8295
)
8396

0 commit comments

Comments
 (0)