-
So, I've got some code and I could really use help understanding its behavior and getting it to do what I want. My code is as follows:

    from typing import Callable, List

    import chex
    import jax.numpy as jnp
    import jax

    Weights = List[jnp.ndarray]


    @chex.dataclass(frozen=True)
    class Model:
        mult: Callable[
            [jnp.ndarray],
            jnp.ndarray
        ]
        jitted_mult: Callable[
            [jnp.ndarray],
            jnp.ndarray
        ]
        weight_updater: Callable[
            [jnp.ndarray], None
        ]


    def create_weight():
        return jnp.ones((2, 5))


    def wrapper():
        weights = create_weight()

        def mult(input_var):
            return weights.dot(input_var)

        @jax.jit
        def jitted_mult(new_weights):
            return weights.dot(input_var)

        def update_locally_created(new_weights):
            nonlocal weights
            weights = new_weights
            return weights

        return Model(
            mult=mult,
            jitted_mult=jitted_mult,
            weight_updater=update_locally_created
        )


    if __name__ == '__main__':
        tester = wrapper()
        to_mult = jnp.ones((5, 2))
        for i in range(5):
            print(jnp.sum(tester.mult(to_mult)))
            print(jnp.sum(tester.jitted_mult(to_mult)))
            if i % 2 == 0:
                tester.weight_updater(jnp.zeros((2, 5)))
            else:
                tester.weight_updater(jnp.ones((2, 5)))
            print("*" * 10)

TL;DR: I'm defining some "weights" within a function closure, and I'm trying to modify the weights through an updater function (`weight_updater`), but the jitted function does not seem to pick up the change. What can I do to make it recognize the update? I think I might be able to do what "Build your own Haiku" does, but that seems like a lot of work for an experiment.
-
Maybe a smaller scope for jit? I run into problems like this sometimes too. I think jitting only the dot operation may work, like this:

    def jitted_mult(input_var):
        return jax.jit(jnp.dot)(weights, input_var)
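A minimal, self-contained sketch of this suggestion, assuming the same shapes as in the question. The names `make_model`, `jitted_dot`, and `update` are illustrative and not from the original post. The idea is that only the pure `jnp.dot` is jitted, while the surrounding Python function reads `weights` on every call and passes it in explicitly:

```python
import jax
import jax.numpy as jnp


def make_model():
    weights = jnp.ones((2, 5))
    # Jit only the pure operation; both operands are explicit arguments.
    jitted_dot = jax.jit(jnp.dot)

    def jitted_mult(input_var):
        # Plain Python: `weights` is looked up on every call, so the
        # current value is what gets passed to the compiled function.
        return jitted_dot(weights, input_var)

    def update(new_weights):
        nonlocal weights
        weights = new_weights

    return jitted_mult, update


jitted_mult, update = make_model()
to_mult = jnp.ones((5, 2))

first = float(jnp.sum(jitted_mult(to_mult)))   # weights are all ones
update(jnp.zeros((2, 5)))
second = float(jnp.sum(jitted_mult(to_mult)))  # weights are all zeros
print(first, second)
```

Because the new weights have the same shape and dtype as the old ones, the second call should reuse the already compiled computation rather than retrace.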
-
I'm not sure I understand your question. Is it possible that you meant for `jitted_mult` to be defined like this?

    @jax.jit
    def jitted_mult(input_var):
        return weights.dot(input_var)

If that's the case, then the reason it's not respecting the update is that your function is not pure (see JAX Sharp Bits: Pure Functions). In your case, the function is not pure because the output depends on an input that is not explicitly passed to the function. This violates the assumptions made by `jit` and other JAX transformations, which leads to unexpected behavior.

To fix it, I would make this implicit input explicit, so that your function is pure. It might look something like this:

    def wrapper():
        def mult(input_var, weights):
            return weights.dot(input_var)

        @jax.jit
        def jitted_mult(input_var, weights):
            return weights.dot(input_var)

        return Model(
            mult=mult,
            jitted_mult=jitted_mult,
            weight_updater=None
        )


    if __name__ == '__main__':
        tester = wrapper()
        to_mult = jnp.ones((5, 2))
        weights = create_weight()
        for i in range(5):
            print(jnp.sum(tester.mult(to_mult, weights)))
            print(jnp.sum(tester.jitted_mult(to_mult, weights)))
            if i % 2 == 0:
                weights = jnp.zeros((2, 5))
            else:
                weights = jnp.ones((2, 5))
            print("*" * 10)
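To make the purity point concrete, here is a small sketch that contrasts the two variants side by side. It assumes the corrected closure version of `jitted_mult` discussed in this reply; `make_closure_version`, `update`, and `pure_mult` are hypothetical names used only for illustration. The closed-over `weights` is read once while `jit` traces the function, so a later `nonlocal` update is not seen by the cached computation, whereas the explicit-argument version follows the value it is given:

```python
import jax
import jax.numpy as jnp


def make_closure_version():
    weights = jnp.ones((2, 5))

    @jax.jit
    def jitted_mult(input_var):
        # Impure: `weights` is an implicit input. Its value at trace
        # time is captured in the compiled computation.
        return weights.dot(input_var)

    def update(new_weights):
        nonlocal weights
        weights = new_weights

    return jitted_mult, update


@jax.jit
def pure_mult(input_var, weights):
    # Pure: everything the output depends on is an explicit argument.
    return weights.dot(input_var)


to_mult = jnp.ones((5, 2))

jitted_mult, update = make_closure_version()
stale_before = float(jnp.sum(jitted_mult(to_mult)))  # traced with ones
update(jnp.zeros((2, 5)))
stale_after = float(jnp.sum(jitted_mult(to_mult)))   # cached, still ones

pure_ones = float(jnp.sum(pure_mult(to_mult, jnp.ones((2, 5)))))
pure_zeros = float(jnp.sum(pure_mult(to_mult, jnp.zeros((2, 5)))))

print(stale_before, stale_after)  # same value twice
print(pure_ones, pure_zeros)      # follows the weights that were passed
```

The closure version would only notice new weights if something forced a retrace (for example, an input with a different shape), which is why passing the weights explicitly is the more predictable choice.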