BindReturn type hint for make_funsor #518
ordabayevy wants to merge 12 commits into pyro-ppl:master from
Conversation
Thanks for adding this! WDYT about making
Do you mean that in the example below:

```python
@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})
```
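As a loose plain-Python illustration (an assumption about the intended semantics, not funsor code), the substitution `x(**{ax.name: ax + kernel})` reads "the output at `(ax, kernel)` is `x[ax + kernel]`", i.e. sliding windows of size `k` over an axis of size `n`, giving a new axis of size `n - k + 1`:

```python
def unroll(x, k):
    """Sliding windows: result[ax][kernel] == x[ax + kernel]."""
    n = len(x)
    # the new "ax" dimension has size n - k + 1; "kernel" ranges over k
    return [[x[ax + kernel] for kernel in range(k)]
            for ax in range(n - k + 1)]
```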
Yes, exactly. On a related note, we should also start using
I like the idea. I'll make the changes then.
Sorry for taking so long to review this (especially since I suggested you try it in the first place). I am still not sure how to go about fixing alpha-conversion as a whole in a way that remains compatible with cons-hashing, so I had put off thinking about details.

I think for the behavior implemented in this PR to be safe and correct by construction in general, we would need to eagerly alpha-mangle the arguments to a

We could write a simple decorator for rewrite rules to perform this extra step:

```python
def bind_args(term):
    def binding_wrapper(rule):
        def wrapped_rule(*args):
            mangled_args = reflect.interpret(term, *args)._ast_values
            return rule(*mangled_args)
        return functools.wraps(rule)(wrapped_rule)
    return binding_wrapper
```

To illustrate the use of

```python
@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
    return None


@eager.register(Softmax, Tensor, Variable)
@bind_args(Softmax)
def _eager_softmax(x, ax):
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()
```

Of course, this is less ergonomic than the original syntax, so I could imagine folding
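Numerically, the eager pattern above is the standard log-sum-exp softmax: subtract the `logaddexp`-reduction along the axis, then exponentiate. A minimal plain-Python sketch of the same computation along one axis (independent of funsor, shown only to make the arithmetic concrete):

```python
import math

def softmax(xs):
    # logsumexp with a max-shift for numerical stability
    m = max(xs)
    lse = m + math.log(sum(math.exp(v - m) for v in xs))
    # x - x.reduce(logaddexp, ax), then exp, as in _eager_softmax
    return [math.exp(v - lse) for v in xs]
```

The max-shift is what makes this safe on large inputs, where a naive `exp(x) / sum(exp(x))` would overflow.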
Addresses #481.
The `BindReturn` type hint is used for binding and returning a variable.
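Binding a variable is what raises the alpha-renaming concern discussed in the review above: two terms that differ only in the name of a bound variable should behave (and hash) the same under cons-hashing. A toy plain-Python sketch of that idea (hypothetical illustration only, not funsor's implementation), which renames every bound variable to a canonical fresh name:

```python
import itertools

# Terms: ("var", name) | ("lam", name, body) | ("app", f, x)
def alpha_normalize(term, env=None, counter=None):
    """Rename bound variables to canonical fresh names ("_v0", "_v1", ...),
    so alpha-equivalent terms compare equal; free variables are kept as-is."""
    env = env or {}
    counter = counter if counter is not None else itertools.count()
    tag = term[0]
    if tag == "var":
        return ("var", env.get(term[1], term[1]))
    if tag == "lam":
        fresh = f"_v{next(counter)}"
        body = alpha_normalize(term[2], {**env, term[1]: fresh}, counter)
        return ("lam", fresh, body)
    # application node: normalize both children with the same counter
    return ("app",
            alpha_normalize(term[1], env, counter),
            alpha_normalize(term[2], env, counter))
```

With a normalization like this applied eagerly at construction time, structural equality on the normalized terms coincides with alpha-equivalence.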