-
Notifications
You must be signed in to change notification settings - Fork 22
Open
Labels
Description
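Under `funsor.interpretations.memoize()`, the `if` condition in `eager_reduce_exp` evaluates to `False`, so the function returns `None` and the term is left lazy. Without memoize, the condition is `True` and the function returns `log_result.exp()` as expected.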
Relevant source (lines 157–165 at commit ca1557b):
@eager.register(Reduce, ops.AddOp, Unary[ops.ExpOp, Funsor], frozenset)
def eager_reduce_exp(op, arg, reduced_vars):
    # Eager rule for Reduce(ops.add, Unary(ops.exp, x), reduced_vars): rewrite the
    # sum of exponentials as the exponential of a logaddexp-reduction of x.
    # x.exp().reduce(ops.add) == x.reduce(ops.logaddexp).exp()
    log_result = arg.arg.reduce(ops.logaddexp, reduced_vars)
    # The rewrite is applied only when log_result is not the *same object* that
    # normalize.interpret produces for the same Reduce(ops.logaddexp, ...) term
    # (presumably meaning "the logaddexp reduction made eager progress").
    # NOTE(review): per the report above, under funsor.interpretations.memoize()
    # this identity check is False for the example below, so None is returned
    # and the term stays lazy. An `is not` comparison is sensitive to how terms
    # are cached/shared under memoize -- TODO confirm root cause.
    if log_result is not normalize.interpret(
        Reduce, ops.logaddexp, arg.arg, reduced_vars
    ):
        return log_result.exp()
    # Returning None presumably declines this pattern (the report shows the term
    # then remains lazy).
    return None
Minimal example reproducing the issue:
from funsor.cnf import Contraction
from funsor.tensor import Tensor
import torch
import funsor.ops as ops
from funsor import Bint, Real
from funsor.terms import Unary, Binary, Variable, Number, eager, lazy, to_data, Reduce
from funsor.constant import Constant
from funsor.delta import Delta
from funsor.integrate import Integrate
import funsor
funsor.set_backend("torch")

# A term of the form Reduce(ops.add, Unary(ops.exp, ...), reduced_vars), i.e. the
# pattern matched by eager_reduce_exp.
cls = Reduce
args = (
    ops.add,
    Unary(
        ops.exp,
        Contraction(
            ops.null,
            ops.add,
            frozenset(),
            (
                Delta(
                    (
                        (
                            "x__BOUND_16",
                            (
                                Tensor(
                                    torch.tensor(
                                        [1, 0, 1, 0, 0, 0, 1, 0, 1, 1],
                                        dtype=torch.int64,
                                    ),
                                    (("plate__BOUND_17", Bint[10]),),
                                    3,
                                ),
                                Number(0.0),
                            ),
                        ),
                    )
                ),
                Tensor(
                    torch.tensor(
                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        dtype=torch.float64,
                    ),
                    (("plate__BOUND_17", Bint[10]),),
                    "real",
                ),
            ),
        ),
    ),
    frozenset({Variable("x__BOUND_16", Bint[3])}),
)

# Without memoize: evaluates to a Tensor.
result = eager.interpret(cls, *args)

with funsor.interpretations.memoize():
    # Under memoize: evaluates to a lazy Contraction term instead.
    result2 = eager.interpret(cls, *args)

# Make the discrepancy visible; the expected behavior is a Tensor in both cases.
print("without memoize:", type(result).__name__)
print("with memoize:   ", type(result2).__name__)