diff --git a/funsor/factory.py b/funsor/factory.py index 127f55b0..cf2f51a3 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -10,7 +10,15 @@ import makefun from funsor.instrument import debug_logged -from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor +from funsor.terms import ( + Funsor, + FunsorMeta, + Subs, + Variable, + eager, + substitute, + to_funsor, +) from funsor.util import as_callable @@ -137,7 +145,9 @@ def _get_dependent_args(fields, hints, args): return { name: arg if isinstance(hint, Value) else arg.output for name, arg, hint in zip(fields, args, hints) - if hint in (Funsor, Bound) or isinstance(hint, (Has, Value)) + if hint in (Funsor, Bound) + or isinstance(hint, (Has, Value)) + or (isinstance(hint, Fresh) and name in hint.args) } @@ -179,19 +189,41 @@ def Unflatten( for name, hint in input_types.items(): if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value, Has))): raise TypeError(f"Invalid type hint {name}: {hint}") + if any( + isinstance(hint, Fresh) and arg in hint.args + for arg, hint in input_types.items() + ): + input_types["bind_return"] = Value[frozenset] + + def new_fn(*args): + args, bind_return = args[:-1], args[-1] + result = fn(*args) + return Subs(result, bind_return) + + else: + new_fn = fn + output_type = input_types.pop("return") hints = tuple(input_types.values()) class ResultMeta(FunsorMeta): - def __call__(cls, *args): + def __call__(cls, *args, bind_return=None): args = list(args) + # Bind-and-return variables + if bind_return is None: + bind_return = frozenset( + (arg, arg) + for hint, arg, arg_name in zip(hints, args, cls._ast_fields) + if isinstance(hint, Fresh) and arg_name in hint.args + ) + # Compute domains of bound variables. for i, (name, arg) in enumerate(zip(cls._ast_fields, args)): hint = input_types[name] if hint is Funsor or isinstance(hint, Has): # TODO support domains args[i] = to_funsor(arg) - elif hint is Bound: + elif hint is Bound or (isinstance(hint, Fresh) and name in hint.args): for other in args: if isinstance(other, Funsor): domain = other.inputs.get(arg, None) @@ -209,10 +241,19 @@ def __call__(cls, *args): # Compute domains of fresh variables. dependent_args = _get_dependent_args(cls._ast_fields, hints, args) - for i, (hint, arg) in enumerate(zip(hints, args)): - if isinstance(hint, Fresh): + for i, (hint, arg, arg_name) in enumerate( + zip(hints, args, cls._ast_fields) + ): + if isinstance(hint, Fresh) and arg_name in hint.args: + domain = hint(**dependent_args) + args[i] = to_funsor(arg.name, domain) + elif isinstance(hint, Fresh): domain = hint(**dependent_args) args[i] = to_funsor(arg, domain) + + # Append bind_return to args + if bind_return: + args.append(bind_return) return super().__call__(*args) @makefun.with_signature( @@ -220,10 +261,12 @@ def __call__(cls, *args): ) def __init__(self, **kwargs): args = tuple(kwargs[k] for k in self._ast_fields) + bind_return = dict(kwargs.get("bind_return", dict())) dependent_args = _get_dependent_args(self._ast_fields, hints, args) output = output_type(**dependent_args) inputs = OrderedDict() bound = {} + fresh = frozenset() for hint, arg, arg_name in zip(hints, args, self._ast_fields): if hint is Funsor: assert isinstance(arg, Funsor) @@ -232,28 +275,45 @@ def __init__(self, **kwargs): assert isinstance(arg, Funsor) inputs.update(arg.inputs) for name in hint.bound: - if kwargs[name] not in arg.input_vars: + if kwargs[name].name not in arg.inputs: warnings.warn( f"Argument {arg_name} is missing bound variable {kwargs[name]} from argument {name}." f"Are you sure {name} will always appear in {arg_name}?", SyntaxWarning, ) - for hint, arg in zip(hints, args): + for hint, arg, arg_name in zip(hints, args, self._ast_fields): if hint is Bound: bound[arg.name] = inputs.pop(arg.name) + elif isinstance(hint, Fresh) and arg_name in hint.args: + bound[arg.name] = inputs.pop(arg.name) + inputs[bind_return[arg.name]] = arg.output + fresh |= frozenset({bind_return[arg.name]}) for hint, arg in zip(hints, args): if isinstance(hint, Fresh): - for k, d in arg.inputs.items(): - if k not in bound: - inputs[k] = d - fresh = frozenset() + if arg.name not in bound: + inputs[arg.name] = arg.output + fresh |= frozenset({arg.name}) Funsor.__init__(self, inputs, output, fresh, bound) for name, arg in zip(self._ast_fields, args): + if name == "bind_return": + arg = dict(arg) setattr(self, name, arg) def _alpha_convert(self, alpha_subs): - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - return Funsor._alpha_convert(self, alpha_subs) + result = [] + new_alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} + for hint, value, arg_name in zip(hints, self._ast_values, self._ast_fields): + if isinstance(hint, Fresh) and arg_name in hint.args: + result.append(to_funsor(alpha_subs[value.name], value.output)) + elif arg_name == "bind_return": + result.append( + frozenset( + (alpha_subs.get(k, k), v) for k, v in self.bind_return.items() + ) + ) + else: + result.append(substitute(value, new_alpha_subs)) + return tuple(result) name = _get_name(fn) ResultMeta.__name__ = f"{name}Meta" @@ -263,7 +323,7 @@ def _alpha_convert(self, alpha_subs): pattern = (Result,) + tuple( _hint_to_pattern(input_types[k]) for k in Result._ast_fields ) - eager.register(*pattern)(_erase_types(fn)) + eager.register(*pattern)(_erase_types(new_fn)) return Result diff --git a/test/test_factory.py b/test/test_factory.py index aff82a9a..887a0eff 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -297,3 +297,44 @@ def MatMul( # To preserve extensionality, should only error on reflect xy = MatMul(x, y, "b") check_funsor(xy, {"a": Bint[3], "c": Bint[4], "d": Bint[3]}, Real) + + +def test_unroll(): + @make_funsor + def Unroll( + x: Has[{"ax"}], # noqa: F821 + ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]], + k: Value[int], + kernel: Fresh[lambda k: Bint[k]], + ) -> Fresh[lambda x: x]: + return x(**{ax.name: ax + kernel}) + + x = random_tensor(OrderedDict(a=Bint[5])) + with reflect: + y = Unroll(x, "a", 2, "kernel") + assert y.fresh == frozenset({"a", "kernel"}) + assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound) + check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real) + z = reinterpret(y) + assert isinstance(z, Tensor) + check_funsor(z, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real) + + +def test_softmax(): + @make_funsor + def Softmax( + x: Has[{"ax"}], # noqa: F821 + ax: Fresh[lambda ax: ax], + ) -> Fresh[lambda x: x]: + y = x - x.reduce(ops.logaddexp, ax) + return y.exp() + + x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) + with reflect: + y = Softmax(x, "a") + assert y.fresh == frozenset({"a"}) + assert all(bound in y.x.inputs and "__BOUND" in bound for bound in y.bound) + check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real) + z = reinterpret(y) + assert isinstance(z, Tensor) + check_funsor(z, {"a": Bint[3], "b": Bint[4]}, Real)