Skip to content

Commit 25e9476

Browse files
Deal with deprecation warning (#4547)
* Make backend compatibility check more robust And deal with deprecation warning (closes #4546). * Remove undefined plot_posterior from __all__ because it's already in az.plots.__all__ Pylint complained about it.
1 parent c1efb7a commit 25e9476

File tree

4 files changed

+47
-10
lines changed

4 files changed

+47
-10
lines changed

pymc3/__init__.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,29 @@
3131
_log.addHandler(handler)
3232

3333

34-
if not semver.match(theano.__version__, ">=1.1.2"):
35-
print(
36-
"!" * 60
37-
+ f"\nThe installed Theano(-PyMC) version ({theano.__version__}) does not match the PyMC3 requirements."
38-
+ "\nFor PyMC3 to work, Theano must be uninstalled and replaced with Theano-PyMC."
39-
+ "\nSee https://github.com/pymc-devs/pymc3/wiki for installation instructions.\n"
40-
+ "!" * 60
41-
)
34+
def _check_backend_version():
35+
backend_paths = theano.__spec__.submodule_search_locations
36+
try:
37+
backend_version = theano.__version__
38+
except:
39+
print(
40+
"!" * 60
41+
+ f"\nThe imported Theano(-PyMC) module is broken."
42+
+ f"\nIt was imported from {backend_paths}"
43+
+ "\nTry to uninstall/reinstall it after closing all active sessions/notebooks."
44+
+ "\nAlso see https://github.com/pymc-devs/pymc3/wiki for installation instructions.\n"
45+
+ "!" * 60
46+
)
47+
return
48+
if not semver.VersionInfo.parse(backend_version).match(">=1.1.2"):
49+
print(
50+
"!" * 60
51+
+ f"\nThe installed Theano(-PyMC) version ({theano.__version__}) does not match the PyMC3 requirements."
52+
+ f"\nIt was imported from {backend_paths}"
53+
+ "\nFor PyMC3 to work, a compatible Theano-PyMC backend version must be installed."
54+
+ "\nSee https://github.com/pymc-devs/pymc3/wiki for installation instructions.\n"
55+
+ "!" * 60
56+
)
4257

4358

4459
def __set_compiler_flags():
@@ -47,6 +62,7 @@ def __set_compiler_flags():
4762
theano.config.gcc__cxxflags = f"{current} -Wno-c++11-narrowing"
4863

4964

65+
_check_backend_version()
5066
__set_compiler_flags()
5167

5268
from pymc3 import gp, ode, sampling

pymc3/plots/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def compareplot(*args, **kwargs):
117117
"compareplot",
118118
"forestplot",
119119
"kdeplot",
120-
"plot_posterior",
121120
"traceplot",
122121
"energyplot",
123122
"densityplot",

pymc3/tests/test_util.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pytest
17+
import theano
1718

1819
from cachetools import cached
1920
from numpy.testing import assert_almost_equal
@@ -25,6 +26,27 @@
2526
from pymc3.util import hash_key, hashable, locally_cachedmethod
2627

2728

29+
class TestBackendVersionCheck:
30+
def test_warn_on_incompatible_backend(self, capsys):
31+
assert not "!!!!!" in capsys.readouterr().out
32+
pm._check_backend_version()
33+
assert not "!!!!!" in capsys.readouterr().out
34+
35+
# Mock an incorrect backend version
36+
original = theano.__version__
37+
38+
theano.__version__ = "1.1.0"
39+
pm._check_backend_version()
40+
assert "does not match" in capsys.readouterr().out
41+
42+
del theano.__version__
43+
pm._check_backend_version()
44+
assert "is broken" in capsys.readouterr().out
45+
46+
theano.__version__ = original
47+
pass
48+
49+
2850
class TestTransformName:
2951
cases = [("var", "var_test__"), ("var_test_", "var_test__test__")]
3052
transform_name = "test"

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ numpy>=1.15.0
66
pandas>=0.24.0
77
patsy>=0.5.1
88
scipy>=1.2.0
9-
semver
9+
semver>=2.13.0
1010
theano-pymc==1.1.2
1111
typing-extensions>=3.7.4

0 commit comments

Comments
 (0)