Skip to content

Commit 9ab2e1e

Browse files
Update pymc_extras/inference/deterministic_advi/dadvi.py
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent f17a090 commit 9ab2e1e

File tree

1 file changed

+3
-8
lines changed
  • pymc_extras/inference/deterministic_advi

1 file changed

+3
-8
lines changed

pymc_extras/inference/deterministic_advi/dadvi.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,11 @@ def create_dadvi_graph(
160160
means = var_params[:n_params]
161161
log_sds = var_params[n_params:]
162162

163-
draw = pt.vector(name="draw", shape=(n_params,))
164-
sample = means + pt.exp(log_sds) * draw
165-
166-
# Graph in terms of a single sample
167-
logp_draw = pytensor.clone_replace(logp, replace={flat_input: sample})
168163
draw_matrix = pt.constant(draws)
169-
170-
# Vectorise
164+
samples = means + pt.exp(log_sds) * draw_matrix
165+
171166
logp_vectorized_draws = pytensor.graph.vectorize_graph(
172-
logp_draw, replace={draw: draw_matrix}
167+
logp, replace={flat_input: samples}
173168
)
174169

175170
mean_log_density = pt.mean(logp_vectorized_draws)

0 commit comments

Comments
 (0)