We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f17a090 commit 9ab2e1eCopy full SHA for 9ab2e1e
pymc_extras/inference/deterministic_advi/dadvi.py
@@ -160,16 +160,11 @@ def create_dadvi_graph(
160
means = var_params[:n_params]
161
log_sds = var_params[n_params:]
162
163
- draw = pt.vector(name="draw", shape=(n_params,))
164
- sample = means + pt.exp(log_sds) * draw
165
-
166
- # Graph in terms of a single sample
167
- logp_draw = pytensor.clone_replace(logp, replace={flat_input: sample})
168
draw_matrix = pt.constant(draws)
169
170
- # Vectorise
+ samples = means + pt.exp(log_sds) * draw_matrix
+
171
logp_vectorized_draws = pytensor.graph.vectorize_graph(
172
- logp_draw, replace={draw: draw_matrix}
+ logp, replace={flat_input: samples}
173
)
174
175
mean_log_density = pt.mean(logp_vectorized_draws)
0 commit comments