Skip to content

Commit 04db7c3

Browse files
refactor: add warning to d calculation
1 parent fda71d6 commit 04db7c3

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pymc_extras/model/marginal/distributions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,12 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
504504
logp = pt.sum([pt.sum(logps_dict[k]) for k in logps_dict])
505505

506506
# Set minimizer initialisation to be random
507-
# Assumes that the observed variable y is the first/only element of values, and that d is shape[-1]
507+
# Assumes that the observed variable y is the only element in values, and that d is shape[-1] - if this is invalid it will simply crash rather than producing an invalid result.
508+
# A more robust method of obtaining d would be ideal.
509+
if len(values) > 1:
510+
warnings.warn(
511+
f"INLA assumes that the latent field {marginalized_vv.name} is of the same shape as the observables, however more than one input value to the logp was provided."
512+
)
508513
d = values[0].data.shape[-1]
509514
rng = np.random.default_rng(op.minimizer_seed)
510515
x0_init = rng.random(d)

0 commit comments

Comments
 (0)