@@ -156,6 +156,9 @@ def kde2plot_op(ax, x, y, grid=200, **kwargs):
156
156
xmax = x .max ()
157
157
ymin = y .min ()
158
158
ymax = y .max ()
159
+ extent = kwargs .pop ('extent' , [])
160
+ if len (extent ) != 4 :
161
+ extent = [xmin , xmax , ymin , ymax ]
159
162
160
163
grid = grid * 1j
161
164
X , Y = np .mgrid [xmin :xmax :grid , ymin :ymax :grid ]
@@ -164,7 +167,7 @@ def kde2plot_op(ax, x, y, grid=200, **kwargs):
164
167
kernel = kde .gaussian_kde (values )
165
168
Z = np .reshape (kernel (positions ).T , X .shape )
166
169
167
- ax .imshow (np .rot90 (Z ), extent = [ xmin , xmax , ymin , ymax ] , ** kwargs )
170
+ ax .imshow (np .rot90 (Z ), extent = extent , ** kwargs )
168
171
169
172
170
173
def kdeplot (data , ax = None ):
@@ -363,9 +366,6 @@ def forestplot(trace_obj, varnames=None, transform=lambda x: x, alpha=0.05, quar
363
366
# Range for x-axis
364
367
plotrange = None
365
368
366
- # Number of chains
367
- chains = None
368
-
369
369
# Subplots
370
370
interval_plot = None
371
371
rhat_plot = None
@@ -648,26 +648,22 @@ def plot_posterior(trace, varnames=None, transform=lambda x: x, figsize=None,
648
648
"""
649
649
650
650
def plot_posterior_op (trace_values , ax ):
651
-
652
651
def format_as_percent (x , round_to = 0 ):
653
- value = np .round (100 * x , round_to )
654
- if round_to == 0 :
655
- value = int (value )
656
- return '{}%' .format (value )
652
+ return '{0:.{1:d}f}%' .format (100 * x , round_to )
657
653
658
654
def display_ref_val (ref_val ):
659
655
less_than_ref_probability = (trace_values < ref_val ).mean ()
660
656
greater_than_ref_probability = (trace_values >= ref_val ).mean ()
661
- ref_in_posterior = format_as_percent (less_than_ref_probability , 1 ) + ' <{:g}< ' .format (ref_val ) + format_as_percent (
662
- greater_than_ref_probability , 1 )
657
+ ref_in_posterior = "{} <{:g}< {}" .format (
658
+ format_as_percent (less_than_ref_probability , 1 ),
659
+ ref_val ,
660
+ format_as_percent (greater_than_ref_probability , 1 ))
663
661
ax .axvline (ref_val , ymin = 0.02 , ymax = .75 , color = 'g' ,
664
662
linewidth = 4 , alpha = 0.65 )
665
663
ax .text (trace_values .mean (), plot_height * 0.6 , ref_in_posterior ,
666
664
size = 14 , horizontalalignment = 'center' )
667
665
668
666
def display_rope (rope ):
669
- pc_in_rope = format_as_percent (np .sum ((trace_values > rope [0 ]) &
670
- (trace_values < rope [1 ])) / len (trace_values ), round_to )
671
667
ax .plot (rope , (plot_height * 0.02 , plot_height * 0.02 ),
672
668
linewidth = 20 , color = 'r' , alpha = 0.75 )
673
669
text_props = dict (size = 16 , horizontalalignment = 'center' , color = 'r' )
@@ -756,15 +752,15 @@ def create_axes_grid(figsize, traces):
756
752
ax [- 1 ].set_axis_off ()
757
753
ax = ax [:- 1 ]
758
754
return ax , fig
759
-
755
+
760
756
def get_trace_dict (tr , varnames ):
761
757
traces = {}
762
758
for v in varnames :
763
759
vals = tr .get_values (v , combine = True , squeeze = True )
764
- if vals .ndim > 1 :
760
+ if vals .ndim > 1 :
765
761
vals_flat = vals .reshape (vals .shape [0 ], - 1 ).T
766
- for i ,vi in enumerate (vals_flat ):
767
- traces ['_' .join ([v ,str (i )])] = vi
762
+ for i , vi in enumerate (vals_flat ):
763
+ traces ['_' .join ([v , str (i )])] = vi
768
764
else :
769
765
traces [v ] = vals
770
766
return traces
0 commit comments