Description
Addresses pyro-ppl/pyro#2929
See design doc
This issue tracks changes needed to efficiently perform variable elimination in Gaussian graphical models with plates. While funsor.sum_product.sum_product() is a partial solution, we'd like to generalize to a complete solution.
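The primitive that variable elimination repeats for Gaussian factors is marginalizing a block of variables out of an information-form Gaussian exp(h·x − ½ xᵀJx). The minimal NumPy sketch below shows that step using the Schur complement J_kk − J_kd J_dd⁻¹ J_dk. It is not funsor's API. The function name `marginalize_info` and the `keep`/`drop` index arguments are hypothetical, and the sketch ignores the normalizing constant.

```python
import numpy as np

def marginalize_info(info_vec, precision, keep, drop):
    """Hypothetical helper: integrate the variables indexed by `drop` out of
    the information-form Gaussian exp(h.x - 0.5 x.J.x) and return (h, J)
    over the variables indexed by `keep`. The log-normalizer is ignored."""
    h_k, h_d = info_vec[keep], info_vec[drop]
    J_kk = precision[np.ix_(keep, keep)]
    J_kd = precision[np.ix_(keep, drop)]
    J_dd = precision[np.ix_(drop, drop)]
    # Solve J_dd^{-1} [J_dk | h_d] in one call instead of forming an inverse.
    sol = np.linalg.solve(J_dd, np.column_stack([J_kd.T, h_d]))
    J_new = J_kk - J_kd @ sol[:, :-1]   # Schur complement of J_dd
    h_new = h_k - J_kd @ sol[:, -1]
    return h_new, J_new
```

As a sanity check, the result should agree with moment form: inv(J_new) equals the kept block of the covariance inv(J), and solve(J_new, h_new) equals the kept entries of the mean. With plates, the hope is to perform this step batched over plate dimensions rather than on one large dense matrix.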
Tasks
- Introduce a new Funsor ConditionalGaussian(info_vec, precision, conditional, inputs) representing the batched conditional distribution of the rightmost real input variable, conditioned on the other real input variables. This could be (i) a new Funsor in addition to Gaussian, (ii) a replacement or generalization of Gaussian, or (iii) a special case of Gaussian where the input info_vec and precision are structured (requires Refactor Gaussian info_vec,precision from backend arrays to Funsors #556). This may allow cheaper linear algebra. Alternatively, see Switch to sqrt(precision) representation in Gaussian? #567.
  Temporary workaround: naively scatter the three parameters (info_vec, precision, conditional) into a dense Gaussian. This can be much more computationally expensive.
- Handle collider variables, where a latent variable outside a plate depends on an upstream latent variable inside the plate, thereby coupling the upstream variables via moralization. Currently such problems cannot even be specified in the plated-einsum DSL.
  Temporary workaround: globally break every plate out of which any arrow leads; this is equivalent to .to_event().
- Handle complete bipartite graphs resulting from the RBM motif (x_i --> y_ij <-- z_j). Currently sum_product() and the TVE algorithm give up in this case with "intractable!".
  Temporary workaround: no known workaround.
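For the first task, the temporary workaround is to scatter a conditional into a dense Gaussian. The sketch below illustrates what that costs, assuming (our assumption, not the proposed ConditionalGaussian's actual parameterization) a linear-Gaussian conditional p(x | y) = N(x; A y + b, P⁻¹). Writing x − A y = B z with z = (y, x) and B = [−A, I] gives the dense factor J = BᵀPB, h = BᵀPb over all of z. The helper name `scatter_conditional_to_dense` is hypothetical.

```python
import numpy as np

def scatter_conditional_to_dense(A, b, P):
    """Hypothetical helper: write the linear-Gaussian conditional
    p(x | y) = N(x; A @ y + b, inv(P)) as a dense information-form factor
    exp(h.z - 0.5 z.J.z) over the concatenation z = (y, x)."""
    dx, dy = A.shape
    B = np.hstack([-A, np.eye(dx)])   # x - A @ y == B @ z
    J = B.T @ P @ B                   # dense (dy+dx) x (dy+dx), rank dx
    h = B.T @ P @ b
    return h, J
```

The dense J is rank-deficient (it is a factor, not a normalizable joint) and stores (dy+dx)² numbers, whereas (A, b, P) needs only dx·dy + dx + dx². Combining the factor with a prior on y recovers the expected joint. For example, with y ~ N(0, I), the x-block of the joint covariance is A Aᵀ + P⁻¹ and the x-mean is b.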
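For the collider task, the small sketch below shows concretely how moralization couples plated variables. It assumes a hypothetical model in which scalar x_i ~ N(0, 1) live in a plate over i and a single global z ~ N(w·x, 1/p) outside the plate has every x_i as a parent. The x-block of the joint precision becomes I + p w wᵀ, which is dense, so the x_i no longer factorize over the plate. The helper name `collider_joint_precision` is hypothetical.

```python
import numpy as np

def collider_joint_precision(w, noise_precision):
    """Joint precision over (x_1..x_N, z) for the hypothetical model
    x_i ~ N(0, 1) independently (a plate over i) and
    z ~ N(w @ x, 1 / noise_precision), with z outside the plate."""
    n = len(w)
    J = np.zeros((n + 1, n + 1))
    J[:n, :n] = np.eye(n)                   # plated prior: diagonal block
    B = np.concatenate([-w, [1.0]])         # z - w @ x == B @ (x, z)
    J += noise_precision * np.outer(B, B)   # collider factor couples all x_i
    return J
```

Every off-diagonal entry p·w_i·w_j of the x-block is a moralization edge between x_i and x_j. This is why the workaround of turning the plate into an event dimension (.to_event()) amounts to handling all N variables as one dense N-dimensional Gaussian.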
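For the RBM task, a pure-Python illustration of the fill-in problem may help. Take the motif x_i --> y_ij <-- z_j and observe every y_ij (or eliminate it first). Moralization then joins every x_i to every z_j, giving the complete bipartite graph K_{m,n}. Eliminating any single vertex turns the whole opposite side into a clique, so the largest factor created grows with the size of one plate. The helper names below are hypothetical. This illustrates graph fill-in only; we have not checked that it is exactly the condition under which sum_product() or TVE reports "intractable!".

```python
from itertools import combinations

def moral_graph_rbm(m, n):
    """Hypothetical helper: moral graph of x_i --> y_ij <-- z_j with every
    y_ij observed, so each y_ij's parents x_i and z_j become adjacent (K_{m,n})."""
    xs = [("x", i) for i in range(m)]
    zs = [("z", j) for j in range(n)]
    adj = {v: set() for v in xs + zs}
    for x in xs:
        for z in zs:
            adj[x].add(z)
            adj[z].add(x)
    return adj

def eliminate(adj, order):
    """Eliminate vertices in `order`, joining each vertex's remaining
    neighbours into a clique; return the largest factor size created."""
    adj = {v: set(nb) for v, nb in adj.items()}  # work on a copy
    largest = 0
    for v in order:
        nbrs = adj.pop(v)
        largest = max(largest, len(nbrs) + 1)    # factor over v and its neighbours
        for a in nbrs:
            adj[a].discard(v)
        for a, b in combinations(nbrs, 2):       # fill-in edges
            adj[a].add(b)
            adj[b].add(a)
    return largest
```

With m = 3 and n = 5, eliminating all x first creates a factor over 6 variables, while eliminating all z first creates one over 4, i.e. min(m, n) + 1 at best for these orders. For Gaussian factors, eliminating a block over k scalar variables costs roughly O(k³), so the cost grows with plate size instead of staying batched.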