I've got the following code pattern:

```python
from functools import partial

import jax
import jax.numpy as jnp

class MyProblem():
    @partial(jax.jit, static_argnums=0)
    def f(self, x, y):
        print("compiling")
        return x + y

    def solve(self, xs, y):
        def calls_f(x):
            return self.f(x, y)
        return jax.pmap(calls_f)(xs)

prob = MyProblem()
prob.solve(jnp.array(range(4)), 0)
```

This seems to trigger compilation every time I run the last line. If I replace … with … I'm wondering how I can rewrite this to avoid the recompilation. Is …
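Hi. In fact …

```python
import jax
import jax.numpy as jnp

class MyProblem():
    # jit here is unnecessary: pmap already compiles f
    def f(self, x, y):
        print("compiling")
        return x + y

    def solve(self, xs, y):
        # Pass y as an argument instead of closing over it;
        # in_axes=(0, None) maps over xs and broadcasts y to every device.
        return jax.pmap(self.f, in_axes=(0, None))(xs, y)

prob = MyProblem()
prob.solve(jnp.array(range(4)), 0)
```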
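A minimal sketch of why avoiding the closure helps, sized with `jax.local_device_count()` so it runs on any number of devices. The names `add`, `trace_count`, `solve_no_closure`, `solve_with_closure` and `run` are hypothetical; the counter only increments while JAX traces `add`, which serves as a proxy for (re)compilation, like the `print("compiling")` above. One assumption here: the pmapped function is built once from a module-level function that takes `y` as an argument.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: bumped only while JAX traces `add`

def add(x, y):
    global trace_count
    trace_count += 1  # Python side effect, runs at trace time only
    return x + y

# No closure: y is an explicit argument broadcast to every device (in_axes=None).
# Built once, so later calls can reuse the cached compiled version.
padd = jax.pmap(add, in_axes=(0, None))

def solve_no_closure(xs, y):
    return padd(xs, y)

def solve_with_closure(xs, y):
    # A brand-new function object on every call, so pmap cannot find it in its cache.
    def calls_add(x):
        return add(x, y)
    return jax.pmap(calls_add)(xs)

def run(solve, xs):
    """Call `solve` three times; return the cumulative trace count after each call."""
    global trace_count
    trace_count = 0
    counts = []
    for y in range(3):
        solve(xs, y)
        counts.append(trace_count)
    return counts

xs = jnp.arange(jax.local_device_count())  # pmap's mapped axis must fit the devices
no_closure_counts = run(solve_no_closure, xs)
closure_counts = run(solve_with_closure, xs)
print("no closure:", no_closure_counts)
print("closure:   ", closure_counts)
```

The first list should stop growing after the first call, while the second grows on every call because each call hands `pmap` a different function object.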
@marius311 I think …

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import Partial  # note: not functools.partial

def call_fun(f, *args, **kwargs):
    # Module-level trampoline: pmap always sees this same function object.
    return f(*args, **kwargs)

class MyProblem():
    @partial(jax.jit, static_argnums=0)
    def f(self, x, y):
        print("compiling")
        return x + y

    def calls_f_open(self, x, y):
        return self.f(x, y)

    def solve(self, xs, y):
        calls_f = Partial(self.calls_f_open, y=y)
        # Now the "closure" is a pytree: self.calls_f_open is its aux_data, which
        # compares and hashes equal across calls for the same self; y is a leaf.
        return jax.pmap(call_fun, in_axes=(None, 0))(calls_f, xs)

prob = MyProblem()
prob.solve(jnp.array(range(4)), 0)
```
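To see why the `Partial` trick keeps the cache key stable, here is a small check that only needs `jax.tree_util`. `Problem` and `make_closure` are hypothetical stand-ins for the class above, not part of the original code.

```python
from jax.tree_util import Partial, tree_leaves, tree_structure

class Problem:
    def calls_f_open(self, x, y):
        return x + y

prob = Problem()

# Each attribute access creates a new bound-method object, but two bound methods
# wrapping the same function and the same instance compare and hash equal.
m1, m2 = prob.calls_f_open, prob.calls_f_open
methods_equal = (m1 == m2) and (hash(m1) == hash(m2))

# Partial is a pytree: the wrapped callable is part of the (static) tree structure,
# while the bound arguments, here y, are ordinary leaves.
p0 = Partial(prob.calls_f_open, y=0)
p1 = Partial(prob.calls_f_open, y=1)
same_structure = tree_structure(p0) == tree_structure(p1)
leaves = tree_leaves(p1)

# A plain closure, by contrast, is a new function object compared by identity.
def make_closure(y):
    def calls_f(x):
        return prob.calls_f_open(x, y)
    return calls_f

closures_equal = make_closure(0) == make_closure(0)
print(methods_equal, same_structure, leaves, closures_equal)
```

Roughly speaking, `pmap` keys its cache on the function together with the structure and abstract values of its arguments, so a `Partial` of the same bound method looks identical from call to call, whereas a fresh closure never does.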
To summarize the answer from the very helpful discussion:
And this comment also explains it. Workarounds here include: don't use a closure, or JIT the entire `solve` function if you can.
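A sketch of the second workaround, jitting the whole `solve`, again using a hypothetical `add` and `trace_count` as a stand-in for `print("compiling")`. The closure is still created inside `solve`, but only while `jax.jit` traces `solve`; later calls with the same input shapes and dtypes should reuse the compiled version and never re-run the Python body.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: bumped only while JAX traces `add`

def add(x, y):
    global trace_count
    trace_count += 1  # Python side effect, runs at trace time only
    return x + y

@jax.jit
def solve(xs, y):
    # Rebuilt only while jit traces `solve`; cached calls skip this entirely.
    def calls_add(x):
        return add(x, y)
    return jax.pmap(calls_add)(xs)

xs = jnp.arange(jax.local_device_count())  # pmap's mapped axis must fit the devices
counts = []
for y in range(3):
    solve(xs, y)  # different Python ints share one abstract value, so no retrace
    counts.append(trace_count)
print(counts)
```

Recent JAX versions may warn that jit-of-pmap can cause extra data movement, since the outer `jit` does not preserve the sharded layout; whether that matters depends on the workload.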