7777"""
7878
7979from collections import OrderedDict
80+ from functools import reduce
8081import warnings
8182
82- from jax import lax , random
83+ from jax import lax , random , tree_map
8384import jax .numpy as jnp
8485
8586import numpyro
86- from numpyro .distributions .distribution import COERCIONS , ExpandedDistribution
87+ from numpyro .distributions .distribution import COERCIONS , ExpandedDistribution , Independent
8788from numpyro .primitives import _PYRO_STACK , Messenger , apply_stack , plate
8889from numpyro .util import not_jax_tracer
8990
@@ -245,6 +246,20 @@ def process_message(self, msg):
245246 msg ['stop' ] = True
246247
247248
249+ def _eager_expand_fn (fn ):
250+ if isinstance (fn , Independent ):
251+ reinterpreted_batch_ndims = fn .reinterpreted_batch_ndims
252+ fn = fn .base_dist
253+ else :
254+ reinterpreted_batch_ndims = 0 # no-op for to_event method
255+ if isinstance (fn , ExpandedDistribution ):
256+ batch_shape = fn .batch_shape
257+ base_batch_shape = fn .base_dist .batch_shape
258+ appended_shape = batch_shape [:len (batch_shape ) - len (base_batch_shape )]
259+ fn = tree_map (lambda x : jnp .broadcast_to (x , appended_shape + jnp .shape (x )), fn .base_dist )
260+ return fn .to_event (reinterpreted_batch_ndims )
261+
262+
248263class collapse (trace ):
249264 """
250265 EXPERIMENTAL Collapses all sites in the context by lazily sampling and
@@ -263,33 +278,48 @@ def __init__(self, *args, **kwargs):
263278 super ().__init__ (* args , ** kwargs )
264279
265280 def process_message (self , msg ):
266- from funsor .terms import Funsor
281+ if msg ["type" ] != "sample" :
282+ return
267283
268- if msg ["type" ] == "sample" :
269- if msg ["value" ] is None :
270- msg ["value" ] = msg ["name" ]
271- if isinstance (msg ["fn" ], ExpandedDistribution ):
272- msg ["fn" ] = msg ["fn" ].base_dist
284+ import funsor
285+
286+ # Eagerly convert fn and value to Funsor.
287+ dim_to_name = {f .dim : f .name for f in msg ["cond_indep_stack" ]}
288+ dim_to_name .update (self .preserved_plates )
289+ if isinstance (msg ["fn" ], (Independent , ExpandedDistribution )):
290+ msg ["fn" ] = _eager_expand_fn (msg ["fn" ])
291+ msg ["fn" ] = funsor .to_funsor (msg ["fn" ], funsor .Real , dim_to_name )
292+ domain = msg ["fn" ].inputs ["value" ]
293+ if msg ["value" ] is None :
294+ msg ["value" ] = funsor .Variable (msg ["name" ], domain )
295+ else :
296+ msg ["value" ] = funsor .to_funsor (msg ["value" ], domain , dim_to_name )
273297
274- if isinstance (msg ["fn" ], Funsor ) or isinstance (msg ["value" ], (str , Funsor )):
275- msg ["stop" ] = True
298+ msg ["stop" ] = True
276299
277300 def __enter__ (self ):
278- self .preserved_plates = frozenset ( h .name for h in _PYRO_STACK
279- if isinstance (h , plate ))
301+ self .preserved_plates = { h . dim : h .name for h in _PYRO_STACK
302+ if isinstance (h , plate )}
280303 COERCIONS .append (self ._coerce )
281304 return super ().__enter__ ()
282305
283306 def __exit__ (self , exc_type , exc_value , traceback ):
284- import funsor
285-
286307 _coerce = COERCIONS .pop ()
287308 assert _coerce is self ._coerce
288309 super ().__exit__ (exc_type , exc_value , traceback )
289310
290311 if exc_type is not None :
312+ self .trace .clear ()
313+ self .preserved_plates .clear ()
291314 return
292315
316+ if any (site ["type" ] == "sample" for site in self .trace .values ()):
317+ name , log_prob , _ , _ = self ._get_log_prob ()
318+ numpyro .factor (name , log_prob .data )
319+
320+ def _get_log_prob (self ):
321+ import funsor
322+
293323 # Convert delayed statements to pyro.factor()
294324 reduced_vars = []
295325 log_prob_terms = []
@@ -299,24 +329,28 @@ def __exit__(self, exc_type, exc_value, traceback):
299329 continue
300330 if not site ["is_observed" ]:
301331 reduced_vars .append (name )
302- dim_to_name = {f .dim : f .name for f in site ["cond_indep_stack" ]}
303- fn = funsor .to_funsor (site ["fn" ], funsor .Real , dim_to_name )
304- value = site ["value" ]
305- if not isinstance (value , str ):
306- value = funsor .to_funsor (site ["value" ], fn .inputs ["value" ], dim_to_name )
307- log_prob_terms .append (fn (value = value ))
332+ log_prob_terms .append (site ["fn" ](value = site ["value" ]))
308333 plates |= frozenset (f .name for f in site ["cond_indep_stack" ])
309- assert log_prob_terms , "nothing to collapse"
310- reduced_plates = plates - self .preserved_plates
311- log_prob = funsor .sum_product .sum_product (
312- funsor .ops .logaddexp ,
313- funsor .ops .add ,
314- log_prob_terms ,
315- eliminate = frozenset (reduced_vars ) | reduced_plates ,
316- plates = plates ,
317- )
318334 name = reduced_vars [0 ]
319- numpyro .factor (name , log_prob .data )
335+ reduced_vars = frozenset (reduced_vars )
336+ assert log_prob_terms , "nothing to collapse"
337+ reduced_plates = plates - frozenset (self .preserved_plates .values ())
338+ self .trace .clear ()
339+ self .preserved_plates .clear ()
340+ if reduced_plates :
341+ log_prob = funsor .sum_product .sum_product (
342+ funsor .ops .logaddexp ,
343+ funsor .ops .add ,
344+ log_prob_terms ,
345+ eliminate = frozenset (reduced_vars ) | reduced_plates ,
346+ plates = plates ,
347+ )
348+ log_joint = NotImplemented
349+ else :
350+ log_joint = reduce (funsor .ops .add , log_prob_terms )
351+ log_prob = log_joint .reduce (funsor .ops .logaddexp , reduced_vars )
352+
353+ return name , log_prob , log_joint , reduced_vars
320354
321355
322356class condition (Messenger ):
0 commit comments