Skip to content

Commit 3ca5f10

Browse files
committed
Keep old function names too
1 parent b3b45da commit 3ca5f10

File tree

1 file changed

+44
-9
lines changed

1 file changed

+44
-9
lines changed

pymc3/plots/__init__.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/
55
for details on plots.
66
"""
7+
import functools
8+
import sys
9+
import warnings
710
try:
811
import arviz as az
912
except ImportError: # arviz is optional, throw exception when used
1013

1114
class _ImportWarner:
15+
__all__ = []
16+
1217
def __init__(self, attr):
1318
self.attr = attr
1419

@@ -21,17 +26,47 @@ class _ArviZ:
2126
def __getattr__(self, attr):
2227
return _ImportWarner(attr)
2328

29+
2430
az = _ArviZ()
2531

32+
def map_args(func):
33+
swaps = [
34+
('varnames', 'var_names')
35+
]
36+
@functools.wraps(func)
37+
def wrapped(*args, **kwargs):
38+
for (old, new) in swaps:
39+
if old in kwargs and new not in kwargs:
40+
warnings.warn('Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8'.format(old=old, new=new))
41+
kwargs[new] = kwargs.pop(old)
42+
return func(*args, **kwargs)
43+
return wrapped
2644

27-
autocorrplot = az.plot_autocorr
28-
compareplot = az.plot_compare
29-
forestplot = az.plot_forest
30-
kdeplot = az.plot_kde
31-
plot_posterior = az.plot_posterior
32-
traceplot = az.plot_trace
33-
energyplot = az.plot_energy
34-
densityplot = az.plot_density
35-
pairplot = az.plot_pair
45+
autocorrplot = map_args(az.plot_autocorr)
46+
compareplot = map_args(az.plot_compare)
47+
forestplot = map_args(az.plot_forest)
48+
kdeplot = map_args(az.plot_kde)
49+
plot_posterior = map_args(az.plot_posterior)
50+
traceplot = map_args(az.plot_trace)
51+
energyplot = map_args(az.plot_energy)
52+
densityplot = map_args(az.plot_density)
53+
pairplot = map_args(az.plot_pair)
3654

3755
from .posteriorplot import plot_posterior_predictive_glm
56+
57+
58+
for plot in az.plots.__all__:
59+
setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot)))
60+
61+
__all__ = tuple(az.plots.__all__) + (
62+
'autocorrplot',
63+
'compareplot',
64+
'forestplot',
65+
'kdeplot',
66+
'plot_posterior',
67+
'traceplot',
68+
'energyplot',
69+
'densityplot',
70+
'pairplot',
71+
'plot_posterior_predictive_glm',
72+
)

0 commit comments

Comments
 (0)