Skip to content

Commit ac09d78

Browse files
committed
make expand work under funsor
1 parent cfdc30d commit ac09d78

File tree

2 files changed

+71
-38
lines changed

2 files changed

+71
-38
lines changed

numpyro/handlers.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@
7777
"""
7878

7979
from collections import OrderedDict
80+
from functools import reduce
8081
import warnings
8182

82-
from jax import lax, random
83+
from jax import lax, random, tree_map
8384
import jax.numpy as jnp
8485

8586
import numpyro
86-
from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution
87+
from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution, Independent
8788
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate
8889
from numpyro.util import not_jax_tracer
8990

@@ -245,6 +246,20 @@ def process_message(self, msg):
245246
msg['stop'] = True
246247

247248

249+
def _eager_expand_fn(fn):
    """Materialize a lazily expanded distribution by broadcasting its base.

    If ``fn`` is an ``Independent`` wrapper, it is unwrapped first and its
    ``reinterpreted_batch_ndims`` is remembered. If the (unwrapped)
    distribution is an ``ExpandedDistribution``, every leaf of its
    ``base_dist`` is broadcast so that the leading batch dimensions added by
    the expansion are prepended to the leaf's own shape. Finally the result is
    passed through ``to_event`` with the remembered number of dims.

    :param fn: the distribution to process.
    :return: ``fn`` (possibly rebuilt from broadcast leaves) after
        ``to_event(reinterpreted_batch_ndims)``.
    """
    # Peel off the Independent wrapper, keeping track of how many batch dims
    # it reinterpreted so they can be re-applied on the way out.
    event_ndims = 0  # no-op for to_event method
    if isinstance(fn, Independent):
        event_ndims = fn.reinterpreted_batch_ndims
        fn = fn.base_dist

    if isinstance(fn, ExpandedDistribution):
        base = fn.base_dist
        expanded_shape = fn.batch_shape
        # Only the dims that the expansion put in front of the base batch
        # shape are taken here.
        # NOTE(review): dims of the base batch shape that were expanded in
        # place (e.g. size 1 -> n) are not part of this leading slice, so they
        # do not appear to be materialized - confirm against callers.
        n_leading = len(expanded_shape) - len(base.batch_shape)
        leading_shape = expanded_shape[:n_leading]

        def _prepend_leading_dims(leaf):
            return jnp.broadcast_to(leaf, leading_shape + jnp.shape(leaf))

        fn = tree_map(_prepend_leading_dims, base)

    return fn.to_event(event_ndims)
261+
262+
248263
class collapse(trace):
249264
"""
250265
EXPERIMENTAL Collapses all sites in the context by lazily sampling and
@@ -263,33 +278,48 @@ def __init__(self, *args, **kwargs):
263278
super().__init__(*args, **kwargs)
264279

265280
def process_message(self, msg):
266-
from funsor.terms import Funsor
281+
if msg["type"] != "sample":
282+
return
267283

268-
if msg["type"] == "sample":
269-
if msg["value"] is None:
270-
msg["value"] = msg["name"]
271-
if isinstance(msg["fn"], ExpandedDistribution):
272-
msg["fn"] = msg["fn"].base_dist
284+
import funsor
285+
286+
# Eagerly convert fn and value to Funsor.
287+
dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
288+
dim_to_name.update(self.preserved_plates)
289+
if isinstance(msg["fn"], (Independent, ExpandedDistribution)):
290+
msg["fn"] = _eager_expand_fn(msg["fn"])
291+
msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
292+
domain = msg["fn"].inputs["value"]
293+
if msg["value"] is None:
294+
msg["value"] = funsor.Variable(msg["name"], domain)
295+
else:
296+
msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name)
273297

274-
if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
275-
msg["stop"] = True
298+
msg["stop"] = True
276299

277300
def __enter__(self):
278-
self.preserved_plates = frozenset(h.name for h in _PYRO_STACK
279-
if isinstance(h, plate))
301+
self.preserved_plates = {h.dim: h.name for h in _PYRO_STACK
302+
if isinstance(h, plate)}
280303
COERCIONS.append(self._coerce)
281304
return super().__enter__()
282305

283306
def __exit__(self, exc_type, exc_value, traceback):
284-
import funsor
285-
286307
_coerce = COERCIONS.pop()
287308
assert _coerce is self._coerce
288309
super().__exit__(exc_type, exc_value, traceback)
289310

290311
if exc_type is not None:
312+
self.trace.clear()
313+
self.preserved_plates.clear()
291314
return
292315

316+
if any(site["type"] == "sample" for site in self.trace.values()):
317+
name, log_prob, _, _ = self._get_log_prob()
318+
numpyro.factor(name, log_prob.data)
319+
320+
def _get_log_prob(self):
321+
import funsor
322+
293323
# Convert delayed statements to numpyro.factor()
294324
reduced_vars = []
295325
log_prob_terms = []
@@ -299,24 +329,28 @@ def __exit__(self, exc_type, exc_value, traceback):
299329
continue
300330
if not site["is_observed"]:
301331
reduced_vars.append(name)
302-
dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
303-
fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
304-
value = site["value"]
305-
if not isinstance(value, str):
306-
value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
307-
log_prob_terms.append(fn(value=value))
332+
log_prob_terms.append(site["fn"](value=site["value"]))
308333
plates |= frozenset(f.name for f in site["cond_indep_stack"])
309-
assert log_prob_terms, "nothing to collapse"
310-
reduced_plates = plates - self.preserved_plates
311-
log_prob = funsor.sum_product.sum_product(
312-
funsor.ops.logaddexp,
313-
funsor.ops.add,
314-
log_prob_terms,
315-
eliminate=frozenset(reduced_vars) | reduced_plates,
316-
plates=plates,
317-
)
318334
name = reduced_vars[0]
319-
numpyro.factor(name, log_prob.data)
335+
reduced_vars = frozenset(reduced_vars)
336+
assert log_prob_terms, "nothing to collapse"
337+
reduced_plates = plates - frozenset(self.preserved_plates.values())
338+
self.trace.clear()
339+
self.preserved_plates.clear()
340+
if reduced_plates:
341+
log_prob = funsor.sum_product.sum_product(
342+
funsor.ops.logaddexp,
343+
funsor.ops.add,
344+
log_prob_terms,
345+
eliminate=frozenset(reduced_vars) | reduced_plates,
346+
plates=plates,
347+
)
348+
log_joint = NotImplemented
349+
else:
350+
log_joint = reduce(funsor.ops.add, log_prob_terms)
351+
log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)
352+
353+
return name, log_prob, log_joint, reduced_vars
320354

321355

322356
class condition(Messenger):

test/test_handlers.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,15 +656,14 @@ def test_collapse_normal_mvn_mvn():
656656
def model():
657657
x = numpyro.sample("x", dist.Exponential(1))
658658
with handlers.collapse():
659-
with numpyro.plate("d", d):
660-
# TODO: verify that to_event works here
661-
beta0 = numpyro.sample("beta0", dist.Normal(0, 1).expand([S]).to_event(1))
662-
# TODO: address beta0 is a str, which cannot do infer_param_domain
663-
beta = numpyro.sample("beta", dist.MultivariateNormal(beta0, jnp.eye(S)))
664-
# FIXME: beta is a string here, how to apply numeric operators
659+
with numpyro.plate("d", d, dim=-1):
660+
beta0 = numpyro.sample("beta0", dist.Normal(0., 1.).expand([d, S]).to_event(1))
661+
beta = numpyro.sample(
662+
"beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S)))
663+
665664
mean = jnp.ones((T, d)) @ beta
666-
with numpyro.plate("data", T, dim=-2):
667-
numpyro.sample("obs", dist.MultivariateNormal(mean, jnp.eye(S)), obs=data)
665+
with numpyro.plate("data", T, dim=-1):
666+
numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data)
668667

669668
def guide():
670669
rate = numpyro.param("rate", 1., constraint=constraints.positive)

0 commit comments

Comments
 (0)