Skip to content

Commit cb070aa

Browse files
author
Martin Ingram
committed
Replace with pt.split
1 parent 7cd407e commit cb070aa

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc_extras/inference/dadvi/dadvi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ def create_dadvi_graph(
197197
)
198198

199199
var_params = pt.vector(name="eta", shape=(2 * n_params,))
200-
means, log_sds = var_params[:n_params], var_params[n_params:]
200+
201+
means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)
201202

202203
draw_matrix = pt.constant(draws)
203204
samples = means + pt.exp(log_sds) * draw_matrix

0 commit comments

Comments
 (0)