Skip to content

Commit 34dfdfa

Browse files
set d automatically
1 parent dd54a37 commit 34dfdfa

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

pymc_extras/model/marginal/distributions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,11 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
438438
)
439439

440440
# Set minimizer initialisation to be random
441-
d = 3 # 10000 # TODO pull this from x.shape (or similar) somehow
441+
# TODO Assumes that the observed variable y is the first/only element of values, and that d is shape[-1]
442+
d = values[0].data.shape[-1]
442443
rng = np.random.default_rng(12345)
443-
x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: rng.random(d)})
444+
x0_init = rng.random(d)
445+
x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: x0_init})
444446

445447
# TODO USE CLOSED FORM SOLUTION FOR NOW
446448
n, y_obs = op.temp_kwargs

0 commit comments

Comments
 (0)