@@ -365,7 +365,7 @@ def mc_error(x, batches=5):
365
365
x : Numpy array
366
366
An array containing MCMC samples
367
367
batches : integer
368
- Number of batchas
368
+ Number of batches
369
369
"""
370
370
371
371
if x .ndim > 1 :
@@ -428,7 +428,7 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5), transform=lambda x: x):
428
428
429
429
430
430
def df_summary (trace , varnames = None , stat_funcs = None , extend = False , include_transformed = False ,
431
- alpha = 0.05 , batches = 100 ):
431
+ alpha = 0.05 , batches = None ):
432
432
R"""Create a data frame with summary statistics.
433
433
434
434
Parameters
@@ -458,15 +458,15 @@ def df_summary(trace, varnames=None, stat_funcs=None, extend=False, include_tran
458
458
addition to, rather than in place of, the default statistics.
459
459
This is only meaningful when `stat_funcs` is not None.
460
460
include_transformed : bool
461
- Flag for reporting automatically transformed variables in addition to
462
- original variables (defaults to False).
461
+ Flag for reporting automatically transformed variables in addition
462
+ to original variables (defaults to False).
463
463
alpha : float
464
464
The alpha level for generating posterior intervals. Defaults
465
465
to 0.05. This is only meaningful when `stat_funcs` is None.
466
- batches : int
467
- Batch size for calculating standard deviation for
468
- non-independent samples. Defaults to 100. This is only
469
- meaningful when `stat_funcs` is None.
466
+ batches : None or int
467
+ Number of batches for calculating standard deviation for non-independent
468
+ samples. Defaults to the smaller of 100 or the number of samples.
469
+ This is only meaningful when `stat_funcs` is None.
470
470
471
471
472
472
See also
@@ -509,6 +509,9 @@ def df_summary(trace, varnames=None, stat_funcs=None, extend=False, include_tran
509
509
else :
510
510
varnames = [name for name in trace .varnames if not name .endswith ('_' )]
511
511
512
+ if batches is None :
513
+ batches = min ([100 , len (trace )])
514
+
512
515
funcs = [lambda x : pd .Series (np .mean (x , 0 ), name = 'mean' ),
513
516
lambda x : pd .Series (np .std (x , 0 ), name = 'sd' ),
514
517
lambda x : pd .Series (mc_error (x , batches ), name = 'mc_error' ),
@@ -535,7 +538,7 @@ def _hpd_df(x, alpha):
535
538
return pd .DataFrame (hpd (x , alpha ), columns = cnames )
536
539
537
540
538
- def summary (trace , varnames = None , alpha = 0.05 , start = 0 , batches = 100 , roundto = 3 ,
541
+ def summary (trace , varnames = None , alpha = 0.05 , start = 0 , batches = None , roundto = 3 ,
539
542
include_transformed = False , to_file = None ):
540
543
R"""
541
544
Generate a pretty-printed summary of the node.
@@ -553,9 +556,10 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
553
556
start : int
554
557
The starting index from which to summarize (each) chain. Defaults
555
558
to zero.
556
- batches : int
557
- Batch size for calculating standard deviation for non-independent
558
- samples. Defaults to 100.
559
+ batches : None or int
560
+ Number of batches for calculating standard deviation for non-independent
561
+ samples. Defaults to the smaller of 100 or the number of samples.
562
+ A value of None selects this default.
559
563
roundto : int
560
564
The number of digits to round posterior statistics.
561
565
include_transformed : bool
@@ -571,6 +575,9 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
571
575
else :
572
576
varnames = [name for name in trace .varnames if not name .endswith ('_' )]
573
577
578
+ if batches is None :
579
+ batches = min ([100 , len (trace )])
580
+
574
581
stat_summ = _StatSummary (roundto , batches , alpha )
575
582
pq_summ = _PosteriorQuantileSummary (roundto , alpha )
576
583
0 commit comments