Hi all, I have a follow-up MWE of the stateful computation tutorial, where the …

```python
import jax
import jax.numpy as jnp


class Counter:
    """A simple counter."""

    def __init__(self, n=None):
        self.n = n if n is not None else 0

    def count(self) -> int:
        """Increments the counter and returns the new value."""
        self.n += 1
        return self.n

    def reset(self):
        """Resets the counter to zero."""
        self.n = 0


jax.tree_util.register_pytree_node(
    Counter,
    lambda self: ((self.n,), None),
    lambda _, xs: Counter(xs[0]),
)

counter = Counter()
counter.reset()

# LOOP 1
@jax.jit
def loop_fun(counter):
    for _ in range(3):
        output = counter.count()
        jax.debug.print("{i}", i=output)  # prints 1, 2, 3

loop_fun(counter)

# or scan equivalent
counter.reset()

# LOOP 2
_, res = jax.lax.scan(lambda c, _: (c, c.count()), init=counter, xs=jnp.arange(3))
print(res)  # prints [1 2 3]
```

In loops 1 and 2, the loop iterations are taken into account by the … Are we getting the expected result only by chance? Is returning …

Thanks for your insights :)
Hi, thanks for the question and the reply. I have an extra follow-up question elaborating on @HGangloff's example, but merely replacing the …
I tried to return …
For context: I am asking this to debug or rewrite a more elaborate piece of code where a custom class (registered as a PyTree) implements a method …
Both cases are fine, because the functions you're using (`loop_fun` in the case of `jit` and the `lambda` in the case of `scan`) are non-side-effecting in themselves. The only side effects that are problematic are those that happen across the boundaries of functions passed to transformations like `jit` or to higher-order primitives like `scan`.

You don't seem to be depending on the global state of `counter` being updated in either case, so your functions are effectively pure. Does that make sense?
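A minimal sketch of the same point, assuming the `Counter` pytree from the question (redefined here without `reset` so the snippet runs on its own). Returning the counter from `loop_fun` is an illustrative change, not part of the original MWE; it follows the explicit-state style of the stateful computation tutorial. Inside `jit` and `scan`, `counter` is a fresh `Counter` rebuilt from the traced leaves by the unflatten function, so the in-place `self.n += 1` only mutates that local object and never reaches the caller's one.

```python
import jax
import jax.numpy as jnp


class Counter:
    """The counter from the question, without `reset` for brevity."""

    def __init__(self, n=None):
        self.n = n if n is not None else 0

    def count(self) -> int:
        """Increments the counter and returns the new value."""
        self.n += 1
        return self.n


jax.tree_util.register_pytree_node(
    Counter,
    lambda c: ((c.n,), None),      # flatten: the single leaf is `n`
    lambda _, xs: Counter(xs[0]),  # unflatten: builds a new Counter object
)


@jax.jit
def loop_fun(counter):
    # `counter` here is a fresh Counter rebuilt from traced leaves, so the
    # mutation below only affects this local object.
    outputs = [counter.count() for _ in range(3)]
    # Hand the updated state back explicitly instead of relying on mutation.
    return counter, jnp.stack(outputs)


counter = Counter()
new_counter, outputs = loop_fun(counter)
print(counter.n)      # 0: the caller's object was never touched
print(new_counter.n)  # 3: updated state comes back through the return value
print(outputs)        # [1 2 3]

# scan version: the carry is the only channel through which state flows out.
final_counter, res = jax.lax.scan(
    lambda c, _: (c, c.count()), init=Counter(), xs=jnp.arange(3)
)
print(res)              # [1 2 3]
print(final_counter.n)  # 3
```

The first `print` shows the caller's `counter` still at 0 after the jitted call, which is why neither loop in the question relies on mutation leaking out of the traced function. Any state you need afterwards has to come back through a return value or the `scan` carry.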