diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 0b173959..d8b82e93 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -286,6 +286,9 @@ def _safesub(x, y): @ops.scatter.register(array, tuple, array) def _scatter(dest, indices, src): + missing = len(indices) - len(dest.shape) + if missing > 0: + dest = dest[(None,) * missing] return index_update(dest, indices, src) diff --git a/funsor/joint.py b/funsor/joint.py index 3f7a78b4..9244d90f 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -16,7 +16,7 @@ from funsor.interpretations import eager, moment_matching, normalize from funsor.ops import AssociativeOp from funsor.tensor import Tensor, align_tensor -from funsor.terms import Funsor, Independent, Number, Reduce, Unary +from funsor.terms import Funsor, Independent, Number, Reduce, Scatter, Unary from funsor.typing import Variadic @@ -74,6 +74,26 @@ def eager_cat_homogeneous(name, part_name, *parts): return result +# FIXME this is too aggressive, but does fix some numpyro tests in +# motivated by https://github.com/pyro-ppl/numpyro/pull/991 +@eager.register( + Contraction, + AssociativeOp, + AssociativeOp, + frozenset, + Delta[Tuple[Tuple[str, Tuple[Tensor, Number]], ...]], + Tensor, +) +def eager_delta_to_scatter(red_op, bin_op, reduced_vars, delta, tensor): + source = tensor + subs = {} + for name, (point, log_density) in delta.terms: + subs[name] = point + source = bin_op(source, log_density) + subs = tuple(subs.items()) + return Scatter(red_op, subs, source, reduced_vars) + + ################################# # patterns for moment-matching ################################# diff --git a/funsor/tensor.py b/funsor/tensor.py index 1c5853d3..878722f7 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -3,6 +3,7 @@ import functools import itertools +import math import typing import warnings from collections import Counter, OrderedDict @@ -636,6 +637,7 @@ def eager_scatter_number(op, subs, source, reduced_vars): return eager_scatter_tensor(op, subs, source, reduced_vars) +# FIXME Does this blow out one-hot tensors, using unnecessarily much memory? @eager.register(Scatter, Op, tuple, Tensor, frozenset) def eager_scatter_tensor(op, subs, source, reduced_vars): if not all(isinstance(v, (Variable, Number, Slice, Tensor)) for k, v in subs): @@ -676,7 +678,7 @@ def eager_scatter_tensor(op, subs, source, reduced_vars): # Construct a destination backend tensor. output = source.output shape = tuple(d.size for d in destin_inputs.values()) + output.shape - destin = ops.new_full(source.data, shape, ops.UNITS[op]) + destin = ops.new_full(source.data, shape, ops.UNITS.get(op, math.nan)) # TODO Add a check for injectivity and dispatch to scatter_add etc. data = ops.scatter(destin, indices, source_data) diff --git a/test/test_tensor.py b/test/test_tensor.py index a2a733f6..bae36b36 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1475,12 +1475,7 @@ def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): def test_scatter_substitute(): expr = Scatter( ops.logaddexp, - ( - ( - "_time_states_38", - Number(0, 1), - ), - ), + (("_time_states_38", Number(0, 1))), Contraction( ops.null, ops.add, @@ -1511,3 +1506,288 @@ def test_scatter_substitute(): ) expr(_time_states_38="_time_states") + + +# motivated by https://github.com/pyro-ppl/numpyro/pull/991 +def test_scatter_dims_error(): + op = ops.null + subs = (('_drop_0', + Tensor( + np.array([[[0, 0], [0, 0]]], dtype=np.int32), + (('_time_states', Bint[1]), + ('_PREV_states', Bint[2]), + ('states', Bint[2])), + 2),),) + # This source is invalid because the input _drop_0 should have been + # substituted away by the above subs. + source = Tensor( + np.array([[[-1.7461166381835938, -3.480717658996582], + [-5.133678436279297, -3.3990774154663086]]], dtype=np.float32), + (('_time_states', Bint[1]), + ('states', Bint[2]), + ('_drop_0', Bint[2])), + 'real') + reduced_vars = frozenset() + + with pytest.raises(Exception): + Scatter(op, subs, source, reduced_vars) + + +# motivated by https://github.com/pyro-ppl/numpyro/pull/991 +def test_infer_discrete_hmm_scan_1(): + from math import inf + actual = Scatter( + ops.max, + (('_time_states', Slice('_time_states__BOUND_21', 1, 2, 2, 2)), + ('_PREV_states', Variable('_drop_0__BOUND_20', Bint[2]))), + Contraction(ops.max, ops.add, + frozenset({Variable('_PREV_states__BOUND_14', Bint[2])}), + (Delta( + (('_drop_0__BOUND_20', + (Tensor( + np.array([[[0, 0], [0, 0]]], dtype=np.int32), + (('_time_states__BOUND_21', Bint[1]), + ('_PREV_states__BOUND_14', Bint[2]), + ('states', Bint[2])), + 2), + Number(0.0),),),)), + Tensor( + np.array([[[-inf, -inf], [-inf, -inf]]], dtype=np.float32), + (('_time_states__BOUND_21', Bint[1]), + ('_PREV_states__BOUND_14', Bint[2]), + ('states', Bint[2])), + 'real'),)), + frozenset({Variable('_drop_0__BOUND_20', Bint[2]), Variable('_time_states__BOUND_21', Bint[1])})) + assert isinstance(actual, Tensor), actual.pretty() + + +def test_infer_discrete_hmm_scan_2(): + from math import inf, nan + actual = Contraction(ops.max, ops.add, + frozenset({Variable('states__BOUND_27', Bint[2])}), + (Delta( + (('_drop_0__BOUND_46', + (Tensor( + np.array([[[0, 1], [1, 1]]], dtype=np.int32), + (('_time_states__BOUND_45', + Bint[1],), + ('_PREV_states', + Bint[2],), + ('states__BOUND_27', + Bint[2],),), + 2), + Number(0.0),),),)), + Tensor( + np.array([[[-2.827766537666321], [nan]], [[nan], [nan]]], dtype=np.float64), # noqa + (('_PREV_states', + Bint[2],), + ('states__BOUND_27', + Bint[2],), + ('_time_states__BOUND_45', + Bint[1],),), + 'real'),)) + assert isinstance(actual, Tensor), actual.pretty() + + +def test_infer_discrete_hmm_scan_3(): + from math import inf + actual = Contraction(ops.null, ops.max, + frozenset(), + (Contraction(ops.null, ops.add, + frozenset(), + (Tensor( + np.array([[-2.110203742980957, -1.0666589736938477], [-1.5835977792739868, -3.236558437347412]], dtype=np.float32), # noqa + (('_time_states', + Bint[2],), + ('states', + Bint[2],),), + 'real'), + Scatter( + ops.max, + (('_time_states', + Slice('_time_states__BOUND_21', 1, 2, 2, 2),), + ('_PREV_states', + Variable('_drop_0__BOUND_20', Bint[2]),),), + Contraction(ops.max, ops.add, + frozenset({Variable('_PREV_states__BOUND_14', Bint[2])}), + (Delta( + (('_drop_0__BOUND_20', + (Tensor( + np.array([[[0, 0], [0, 0]]], dtype=np.int32), + (('_time_states__BOUND_21', + Bint[1],), + ('_PREV_states__BOUND_14', + Bint[2],), + ('states', + Bint[2],),), + 2), + Number(0.0),),),)), + Tensor( + np.array([[[-inf, -inf], [-inf, -inf]]], dtype=np.float32), + (('_time_states__BOUND_21', + Bint[1],), + ('_PREV_states__BOUND_14', + Bint[2],), + ('states', + Bint[2],),), + 'real'),)), + frozenset({Variable('_drop_0__BOUND_20', Bint[2]), Variable('_time_states__BOUND_21', Bint[1])})),)), # noqa + Contraction(ops.null, ops.add, + frozenset(), + (Tensor( + np.array([[-2.110203742980957, -1.0666589736938477], [-1.5835977792739868, -3.236558437347412]], dtype=np.float32), # noqa + (('_time_states', + Bint[2],), + ('states', + Bint[2],),), + 'real'), + Scatter( + ops.max, + (('_time_states', + Slice('_time_states__BOUND_25', 0, 2, 2, 2),), + ('states', + Variable('_drop_0__BOUND_24', Bint[2]),),), + Contraction(ops.max, ops.add, + frozenset({Variable('states__BOUND_13', Bint[2])}), + (Delta( + (('_drop_0__BOUND_24', + (Tensor( + np.array([[[0, 0], [0, 0]]], dtype=np.int32), + (('_time_states__BOUND_25', + Bint[1],), + ('_PREV_states', + Bint[2],), + ('states__BOUND_13', + Bint[2],),), + 2), + Number(0.0),),),)), + Tensor( + np.array([[[-inf, -inf], [-inf, -inf]]], dtype=np.float32), + (('_time_states__BOUND_25', + Bint[1],), + ('states__BOUND_13', + Bint[2],), + ('_PREV_states', + Bint[2],),), + 'real'),)), + frozenset({Variable('_drop_0__BOUND_24', Bint[2]), Variable('_time_states__BOUND_25', Bint[1])})),)), # noqa + Contraction(ops.null, ops.add, + frozenset(), + (Tensor( + np.array([[-2.110203742980957, -1.0666589736938477], [-1.5835977792739868, -3.236558437347412]], dtype=np.float32), # noqa + (('_time_states', + Bint[2],), + ('states', + Bint[2],),), + 'real'), + Contraction(ops.null, ops.max, + frozenset(), + (Scatter( + ops.max, + (('_time_states', + Slice('_time_states__BOUND_37', 1, 2, 2, 2),), + ('_PREV_states', + Variable('_drop_0__BOUND_36', Bint[2]),),), + Contraction(ops.max, ops.add, + frozenset({Variable('_PREV_states__BOUND_28', Bint[2])}), + (Delta( + (('_drop_0__BOUND_36', + (Tensor( + np.array([[[0, 1], [1, 1]]], dtype=np.int32), + (('_time_states__BOUND_37', + Bint[1],), + ('_PREV_states__BOUND_28', + Bint[2],), + ('states', + Bint[2],),), + 2), + Number(0.0),),),)), + Tensor( + np.array([[[-2.2727227210998535, -2.9637789726257324], [-1.2291778326034546, -1.2291778326034546]]], dtype=np.float32), # noqa + (('_time_states__BOUND_37', + Bint[1],), + ('_PREV_states__BOUND_28', + Bint[2],), + ('states', + Bint[2],),), + 'real'), + Scatter( + ops.max, + (('_time_states__BOUND_37', + Number(0, 1),),), + Contraction(ops.null, ops.add, + frozenset(), + (Delta( + (('states', + (Tensor( + np.array(0, dtype=np.int32), + (), + 2), + Number(0.0),),), + ('_PREV_states__BOUND_28', + (Tensor( + np.array(0, dtype=np.int32), + (), + 2), + Number(0.0),),),)), + Tensor( + np.array(-1.081649899482727, dtype=np.float32), + (), + 'real'),)), + frozenset()),)), + frozenset({Variable('_drop_0__BOUND_36', Bint[2]), Variable('_time_states__BOUND_37', Bint[1])})), # noqa + Scatter( + ops.max, + (('_time_states', + Slice('_time_states__BOUND_45', 0, 2, 2, 2),), + ('states', + Variable('_drop_0__BOUND_46', Bint[2]),),), + Contraction(ops.max, ops.add, + frozenset({Variable('states__BOUND_27', Bint[2])}), + (Scatter( + ops.max, + (('_time_states__BOUND_45', + Number(0, 1),),), + Contraction(ops.null, ops.add, + frozenset(), + (Delta( + (('_PREV_states', + (Tensor( + np.array(0, dtype=np.int32), + (), + 2), + Number(0.0),),), + ('states__BOUND_27', + (Tensor( + np.array(0, dtype=np.int32), + (), + 2), + Number(0.0),),),)), + Tensor( + np.array(-1.081649899482727, dtype=np.float32), + (), + 'real'),)), + frozenset()), + Delta( + (('_drop_0__BOUND_46', + (Tensor( + np.array([[[0, 1], [1, 1]]], dtype=np.int32), + (('_time_states__BOUND_45', + Bint[1],), + ('_PREV_states', + Bint[2],), + ('states__BOUND_27', + Bint[2],),), + 2), + Number(0.0),),),)), + Tensor( + np.array([[[-1.7461166381835938, -3.480717658996582], [-3.3990774154663086, -3.3990774154663086]]], dtype=np.float32), # noqa + (('_time_states__BOUND_45', + Bint[1],), + ('states__BOUND_27', + Bint[2],), + ('_PREV_states', + Bint[2],),), + 'real'),)), + frozenset({Variable('_time_states__BOUND_45', Bint[1]), Variable('_drop_0__BOUND_46', Bint[2])})),)),)),)) # noqa + assert isinstance(actual, Tensor), actual.quote()