2
2
import numpy as np
3
3
4
4
5
- def compareplot (comp_df , ax = None ):
5
+ def compareplot (comp_df , insample_dev = True , se = True , dse = True , ax = None ,
6
+ plot_kwargs = None ):
6
7
"""
7
8
Model comparison summary plot in the style of the one used in the book
8
9
Statistical Rethinking by Richard McElreath.
@@ -11,9 +12,22 @@ def compareplot(comp_df, ax=None):
11
12
----------
12
13
13
14
comp_df: DataFrame
14
- The result of the pm.compare() function
15
+ the result of the `pm.compare()` function
16
+ insample_dev : bool
17
+ plot the in-sample deviance, that is the value of the IC without the
18
+ penalization given by the effective number of parameters (pIC).
19
+ Defaults to True
20
+ se : bool
21
+ plot the standard error of the IC estimate. Defaults to True
22
+ dse : bool
23
+ plot standard error of the difference in IC between each model and the
24
+ top-ranked model. Defaults to True
25
+ plot_kwargs : dict
26
+ Optional arguments for plot elements. Currently accepts 'color_ic',
27
+ 'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse',
28
+ 'marker_dse', 'ls_min_ic' 'color_ls_min_ic', 'fontsize'
15
29
ax : axes
16
- Matplotlib axes. Defaults to None.
30
+ Matplotlib axes. Defaults to None
17
31
18
32
Returns
19
33
-------
@@ -24,26 +38,59 @@ def compareplot(comp_df, ax=None):
24
38
if ax is None :
25
39
_ , ax = plt .subplots ()
26
40
27
- yticks_pos , step = np .linspace (0 , - 1 , (comp_df .shape [0 ] * 2 ) - 1 , retstep = True )
41
+ if plot_kwargs is None :
42
+ plot_kwargs = {}
43
+
44
+ yticks_pos , step = np .linspace (0 , - 1 , (comp_df .shape [0 ] * 2 ) - 1 ,
45
+ retstep = True )
28
46
yticks_pos [1 ::2 ] = yticks_pos [1 ::2 ] + step / 2
29
47
30
48
yticks_labels = ['' ] * len (yticks_pos )
31
- yticks_labels [0 ] = comp_df .index [0 ]
32
- yticks_labels [1 ::2 ] = comp_df .index [1 :]
33
49
34
- data = comp_df .values
35
- min_ic = data [0 , 0 ]
50
+ if dse :
51
+ yticks_labels [0 ] = comp_df .index [0 ]
52
+ yticks_labels [2 ::2 ] = comp_df .index [1 :]
53
+ ax .set_yticks (yticks_pos )
54
+ ax .errorbar (x = comp_df .WAIC [1 :],
55
+ y = yticks_pos [1 ::2 ],
56
+ xerr = comp_df .dSE [1 :],
57
+ color = plot_kwargs .get ('color_dse' , 'grey' ),
58
+ fmt = plot_kwargs .get ('marker_dse' , '^' ))
59
+
60
+ else :
61
+ yticks_labels = comp_df .index
62
+ ax .set_yticks (yticks_pos [::2 ])
63
+
64
+ if se :
65
+ ax .errorbar (x = comp_df .WAIC ,
66
+ y = yticks_pos [::2 ],
67
+ xerr = comp_df .SE ,
68
+ color = plot_kwargs .get ('color_ic' , 'k' ),
69
+ fmt = plot_kwargs .get ('marker_ic' , 'o' ),
70
+ mfc = 'None' ,
71
+ mew = 1 )
72
+ else :
73
+ ax .plot (comp_df .WAIC ,
74
+ yticks_pos [::2 ],
75
+ color = plot_kwargs .get ('color_ic' , 'k' ),
76
+ marker = plot_kwargs .get ('marker_ic' , 'o' ),
77
+ mfc = 'None' ,
78
+ mew = 1 ,
79
+ lw = 0 )
36
80
37
- ax .errorbar (x = data [:, 0 ], y = yticks_pos [::2 ], xerr = data [:, 4 ],
38
- fmt = 'ko' , mfc = 'None' , mew = 1 )
39
- ax .errorbar (x = data [1 :, 0 ], y = yticks_pos [1 ::2 ],
40
- xerr = data [1 :, 5 ], fmt = '^' , color = 'grey' )
81
+ if insample_dev :
82
+ ax .plot (comp_df .WAIC - (2 * comp_df .pWAIC ),
83
+ yticks_pos [::2 ],
84
+ color = plot_kwargs .get ('color_insample_dev' , 'k' ),
85
+ marker = plot_kwargs .get ('marker_insample_dev' , 'o' ),
86
+ lw = 0 )
41
87
42
- ax .plot (data [:, 0 ] - (2 * data [:, 1 ]), yticks_pos [::2 ], 'ko' )
43
- ax .axvline (min_ic , ls = '--' , color = 'grey' )
88
+ ax .axvline (comp_df .WAIC [0 ],
89
+ ls = plot_kwargs .get ('ls_min_ic' , '--' ),
90
+ color = plot_kwargs .get ('color_ls_min_ic' , 'grey' ))
44
91
45
- ax .set_yticks ( yticks_pos )
92
+ ax .set_xlabel ( 'Deviance' , fontsize = plot_kwargs . get ( 'fontsize' , 14 ) )
46
93
ax .set_yticklabels (yticks_labels )
47
- ax .set_xlabel ( 'Deviance' )
94
+ ax .set_ylim ( - 1 + step , 0 - step )
48
95
49
96
return ax
0 commit comments