Hi all, I'm porting one of my PyTorch codebases to JAX. In it, one big function (which I'll denote F) takes:
and returns some quantity of interest. The inner workings of that function involve a for-loop over the data with some carry-over. So far I have managed to get everything working with lax.scan by putting Theta in the carry, where Theta is a pytree with np.array leaves. However, one convenient aspect of my PyTorch implementation is that some elements of Theta are actually nn modules, which contain their own parameters but can also be called at some point in F. I can pass in any module I want, which is practical for me because these modules correspond to arbitrarily complex parameterized functions on which I want to run experiments. I wanted to reproduce this behaviour by passing tuples of (parameters, function_to_apply) in the carry of lax.scan, or even by passing a Haiku model, but apparently this is not allowed: I get error messages saying that "function" is not a valid type for the carry of scan. I can get around this by passing a "string" in the carry with the name of the function to apply, keeping a dict of these functions somewhere, and using e.g. lax.cond, but that seems a bit ugly, so I was wondering if there is some other way? Thanks in advance!
Replies: 1 comment 10 replies
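I think something like this will work:

from jax import jit
from jax.tree_util import Partial

# Partial makes the wrapped function a pytree node, so it can be passed
# as part of an argument to a jitted function.
th = (Partial(lambda params, x: params['x'] + x), dict(x=1))

@jit
def app(t, x):
    f, params = t
    return f(params, x)

app(th, 1)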
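The same idea should carry over to the carry of lax.scan, which is what the question asks about. Below is a minimal sketch, not a definitive implementation: the function `affine`, the parameter names `w` and `b`, and the running-sum loop body are all made up for illustration and are not from the original code. The `Partial` is returned unchanged from each step so that the carry keeps the same structure across iterations.

```python
import jax.numpy as jnp
from jax import lax
from jax.tree_util import Partial

# Hypothetical parameterized function standing in for an "nn module".
def affine(params, x):
    return params["w"] * x + params["b"]

# Theta bundles the callable (wrapped in Partial) with its parameters.
theta = (Partial(affine), {"w": jnp.float32(2.0), "b": jnp.float32(1.0)})

def step(carry, x):
    theta, acc = carry
    f, params = theta
    acc = acc + f(params, x)      # call the carried function on this data point
    return (theta, acc), acc      # theta is passed through unchanged

xs = jnp.arange(3, dtype=jnp.float32)
(_, total), running = lax.scan(step, (theta, jnp.float32(0.0)), xs)
```

Swapping in a different function only requires changing what is wrapped in `Partial`; the loop body stays the same. If the function never changes during the loop, closing over it in `step` and keeping only the parameters in the carry is a simpler alternative.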