@@ -112,21 +112,34 @@ def waic(trace, model=None, n_eff=False):
         return -2 * lppd + 2 * p_waic, p_waic
     else:
         return -2 * lppd + 2 * p_waic
-
-def loo(trace, model=None):
+
+def loo(trace, model=None, n_eff=False):
     """
     Calculates leave-one-out (LOO) cross-validation for out of sample predictive
     model fit, following Vehtari et al. (2015). Cross-validation is computed using
     Pareto-smoothed importance sampling (PSIS).

-    Returns log pointwise predictive density calculated via approximated LOO cross-validation.
+    Parameters
+    ----------
+    trace : result of MCMC run
+    model : PyMC Model
+        Optional model. Default None, taken from context.
+    n_eff : bool
+        If True, the effective number of parameters will be computed and returned.
+        Default False.
+
+    Returns
+    -------
+    elpd_loo : log pointwise predictive density calculated via approximated LOO cross-validation
+    p_loo : effective number of parameters, only if n_eff is True
     """
     model = modelcontext(model)

     log_py = log_post_trace(trace, model)

     # Importance ratios
-    r = 1. / np.exp(log_py)
+    py = np.exp(log_py)
+    r = 1. / py
     r_sorted = np.sort(r, axis=0)

     # Extract largest 20% of importance ratios and fit generalized Pareto to each
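For reference, a small self-contained sketch of the importance-ratio and weight-truncation computation that this hunk and the next one touch. The dimensions and random stand-in values are assumptions for illustration only, and the Pareto-smoothing step applied to the largest ratios (elided from this diff) is skipped:

import numpy as np

S, N = 1000, 5                            # S posterior draws, N observations
log_py = -np.abs(np.random.randn(S, N))   # stand-in for log p(y_i | theta_s)

py = np.exp(log_py)
r = 1. / py                                    # raw importance ratios
w = np.minimum(r, r.mean(axis=0) * S ** 0.75)  # truncate to keep the variance finite

# Truncated importance-sampling estimate of the LOO log predictive density
loo_lppd = np.sum(np.log(np.sum(w * py, axis=0) / np.sum(w, axis=0)))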
@@ -154,11 +167,14 @@ def loo(trace, model=None):
     # Truncate weights to guarantee finite variance
     w = np.minimum(r_new, r_new.mean(axis=0) * S ** 0.75)

-    loo_lppd = np.sum(np.log(np.sum(w * np.exp(log_py), axis=0) / np.sum(w, axis=0)))
+    loo_lppd = np.sum(np.log(np.sum(w * py, axis=0) / np.sum(w, axis=0)))

-    return loo_lppd
+    if n_eff:
+        p_loo = np.sum(np.log(np.mean(py, axis=0))) - loo_lppd
+        return -2 * loo_lppd, p_loo
+    else:
+        return -2 * loo_lppd

-
 def bpic(trace, model=None):
     """
     Calculates Bayesian predictive information criterion n of the samples in trace from model
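A minimal usage sketch of the changed signature: without n_eff the function now returns -2 * loo_lppd (deviance scale), and with n_eff=True it also returns the effective number of parameters. The toy model, data, and import path below are illustrative assumptions rather than part of this commit:

import numpy as np
import pymc3 as pm
from pymc3.stats import loo   # assumed import path for this module

data = np.random.randn(100)   # toy data, illustration only

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=10.)
    pm.Normal('obs', mu=mu, sd=1., observed=data)
    trace = pm.sample(1000)

loo_val = loo(trace, model=model)                      # -2 * loo_lppd only
loo_val, p_loo = loo(trace, model=model, n_eff=True)   # also the effective number of parameters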