Skip to content

Commit c378850

Browse files
committed
dynamically sets batches to min(n_samples, 100); closes #1606
1 parent 7584b3b commit c378850

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

pymc3/stats.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def mc_error(x, batches=5):
365365
x : Numpy array
366366
An array containing MCMC samples
367367
batches : integer
368-
Number of batchas
368+
Number of batches
369369
"""
370370

371371
if x.ndim > 1:
@@ -428,7 +428,7 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5), transform=lambda x: x):
428428

429429

430430
def df_summary(trace, varnames=None, stat_funcs=None, extend=False, include_transformed=False,
431-
alpha=0.05, batches=100):
431+
alpha=0.05, batches=None):
432432
R"""Create a data frame with summary statistics.
433433
434434
Parameters
@@ -458,15 +458,15 @@ def df_summary(trace, varnames=None, stat_funcs=None, extend=False, include_tran
458458
addition to, rather than in place of, the default statistics.
459459
This is only meaningful when `stat_funcs` is not None.
460460
include_transformed : bool
461-
Flag for reporting automatically transformed variables in addition to
462-
original variables (defaults to False).
461+
Flag for reporting automatically transformed variables in addition
462+
to original variables (defaults to False).
463463
alpha : float
464464
The alpha level for generating posterior intervals. Defaults
465465
to 0.05. This is only meaningful when `stat_funcs` is None.
466-
batches : int
467-
Batch size for calculating standard deviation for
468-
non-independent samples. Defaults to 100. This is only
469-
meaningful when `stat_funcs` is None.
466+
batches : None or int
467+
Number of batches used to calculate the standard deviation for non-independent
468+
samples. Defaults to the smaller of 100 or the number of samples.
469+
This is only meaningful when `stat_funcs` is None.
470470
471471
472472
See also
@@ -509,6 +509,9 @@ def df_summary(trace, varnames=None, stat_funcs=None, extend=False, include_tran
509509
else:
510510
varnames = [name for name in trace.varnames if not name.endswith('_')]
511511

512+
if batches is None:
513+
batches = min([100, len(trace)])
514+
512515
funcs = [lambda x: pd.Series(np.mean(x, 0), name='mean'),
513516
lambda x: pd.Series(np.std(x, 0), name='sd'),
514517
lambda x: pd.Series(mc_error(x, batches), name='mc_error'),
@@ -535,7 +538,7 @@ def _hpd_df(x, alpha):
535538
return pd.DataFrame(hpd(x, alpha), columns=cnames)
536539

537540

538-
def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
541+
def summary(trace, varnames=None, alpha=0.05, start=0, batches=None, roundto=3,
539542
include_transformed=False, to_file=None):
540543
R"""
541544
Generate a pretty-printed summary of the node.
@@ -553,9 +556,10 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
553556
start : int
554557
The starting index from which to summarize (each) chain. Defaults
555558
to zero.
556-
batches : int
557-
Batch size for calculating standard deviation for non-independent
558-
samples. Defaults to 100.
559+
batches : None or int
560+
Number of batches used to calculate the standard deviation for non-independent
561+
samples. Defaults to the smaller of 100 or the number of samples.
562+
`summary` takes no `stat_funcs` argument; this value is always passed to `_StatSummary`.
559563
roundto : int
560564
The number of digits to round posterior statistics.
561565
include_transformed : bool
@@ -571,6 +575,9 @@ def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
571575
else:
572576
varnames = [name for name in trace.varnames if not name.endswith('_')]
573577

578+
if batches is None:
579+
batches = min([100, len(trace)])
580+
574581
stat_summ = _StatSummary(roundto, batches, alpha)
575582
pq_summ = _PosteriorQuantileSummary(roundto, alpha)
576583

0 commit comments

Comments
 (0)