How to keep hashable arguments static through a scan #16667
-
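Thanks for the question! Any of the arguments that pass through the carry of `lax.scan` are traced, so they can't be used as static arguments. Instead, keep the static value out of the carry and close over it:

```python
from jax import lax


def loop_function(itermax):
    iter = 0
    nonstatic = 1
    static = 2  # closed over below, so it never enters the scan carry

    def scanner_closure(a, x):
        # `scanner` is the function from the question (not shown here);
        # only the traced values travel through the carry `a`.
        a, x = scanner((a[0], a[1], static), x)
        return (a[0], a[1]), x

    a, x = lax.scan(scanner_closure, (iter, nonstatic), xs=None, length=itermax)
    return a, x
```

Now …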
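The closure pattern can be checked end to end with a small stand-in. In this sketch, `step` is a hypothetical replacement for the asker's jitted function (which is not shown), with its third argument marked static via `static_argnums`. Because `static` is captured by the closure instead of riding in the carry, it reaches `step` as a plain, hashable Python int; the update rule and the `% 1000` are made up purely for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-in for the asker's jitted function. Argument 2 is static,
# so it must receive a hashable Python value, never a tracer.
@partial(jax.jit, static_argnums=2)
def step(it, nonstatic, static):
    # `static` is a concrete Python int here, so Python control flow works.
    if static > 1:
        nonstatic = (nonstatic * static) % 1000
    return it + 1, nonstatic


def loop_function(itermax):
    static = 2  # closed over by `body`; never placed in the scan carry

    def body(carry, x):
        it, nonstatic = carry  # these two are traced inside scan
        it, nonstatic = step(it, nonstatic, static)
        return (it, nonstatic), x

    init = (jnp.int32(0), jnp.int32(1))
    carry, _ = lax.scan(body, init, xs=None, length=itermax)
    return carry


it, nonstatic = loop_function(10)
print(int(it), int(nonstatic))  # prints "10 24"
```

Each distinct value of `static` triggers a separate compilation of `step`, which is the usual trade-off of `static_argnums`.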
-
This is a very common issue. As a follow-up, you may like …
-
For a native JAX solution, I think handling the carry of a scanned function as a pytree-registered dataclass, with the static value marked as a static field, works well:

```python
import jax
from dataclasses import dataclass, field


@jax.tree_util.register_dataclass
@dataclass
class ScanCarry:
    """Named scan args"""
    iter: jax.Array
    nonstatic: jax.Array
    # Marked static: stored in the pytree's treedef, so it stays a Python int.
    static: int = field(metadata={"static": True})


def loop_function(itermax):
    carry = ScanCarry(
        iter=0,
        nonstatic=1,
        static=2,
    )
    return jax.lax.scan(body, carry, xs=None, length=itermax)


def body(a, x):
    nonstatic = a.nonstatic * a.static
    iter = a.iter + 1 + nonstatic % 100
    return ScanCarry(iter, nonstatic, a.static), x


if __name__ == '__main__':
    print(loop_function(1000))
```

I think it improves readability and gets rid of too many …
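If your installed JAX version does not infer the fields from `field(metadata=...)` when `register_dataclass` is used as a bare decorator (worth checking against your version), the same carry can be registered by listing `data_fields` and `meta_fields` explicitly. This is a sketch under that assumption: the update rule is simplified from the answer above, and the `% 1000` only keeps the numbers small. The `assert` in `body` runs at trace time and shows that the meta field arrives as a plain int.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ScanCarry:
    iter: jax.Array
    nonstatic: jax.Array
    static: int


# Explicit registration: data fields are pytree leaves (traced inside scan),
# while meta fields live in the treedef and stay concrete Python values.
jax.tree_util.register_dataclass(
    ScanCarry, data_fields=["iter", "nonstatic"], meta_fields=["static"]
)


def body(carry, x):
    # Runs at trace time: `static` arrives as a plain int, not a tracer.
    assert isinstance(carry.static, int)
    nonstatic = (carry.nonstatic * carry.static) % 1000  # % keeps values small
    return ScanCarry(carry.iter + 1, nonstatic, carry.static), x


def loop_function(itermax):
    init = ScanCarry(iter=jnp.int32(0), nonstatic=jnp.int32(1), static=2)
    carry, _ = jax.lax.scan(body, init, xs=None, length=itermax)
    return carry


result = loop_function(10)
print(result.static, int(result.iter), int(result.nonstatic))  # prints "2 10 24"
```

Since the static value is part of the treedef, `body` must return it unchanged; otherwise the output carry structure would not match the input carry and `scan` would reject it.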
-
Hi, disclaimer: I am new to JAX but it has been awesome so far! :D
I have a function (`jitfunc`) that is jitted and works properly. It is called inside a Python `while` loop, and that loop as a whole is very slow, whereas the function itself is fast. I would like to convert the loop to `lax.while_loop` or `lax.fori_loop`, but most likely `lax.scan`, since it is differentiable and the next step for this project is AI stuff.
I have a maximum number of iterations, `itermax`, but I use a while loop instead of a for loop because the jitted function can sometimes speed up the calculation and increase the iteration counter by more than 1.
In that scenario, I was planning to use `scan` and have the function do nothing once the counter has reached `itermax` (I think this is doable, but I have no idea whether it is optimal).
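A minimal sketch of that plan, assuming a hypothetical `jitfunc(it, state)` that returns the updated counter and state (the real function is not shown; this stand-in advances by 2 whenever `state` is even, otherwise by 1). The scan runs a fixed `itermax` steps, and `lax.cond` turns every step after the counter reaches `itermax` into a no-op.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def jitfunc(it, state):
    # Hypothetical stand-in: advances the counter by 2 when `state` is even,
    # otherwise by 1, mimicking a step that can skip ahead.
    step = jnp.where(state % 2 == 0, 2, 1)
    return it + step, state + 1


def run(itermax):
    def body(carry, _):
        it, state = carry
        # Real step while it < itermax; afterwards return the carry unchanged.
        new_carry = lax.cond(
            it < itermax,
            lambda c: jitfunc(*c),
            lambda c: c,
            (it, state),
        )
        return new_carry, None

    init = (jnp.int32(0), jnp.int32(0))
    # The counter grows by at least 1 per real step, so `itermax` scan steps
    # are always enough to reach it.
    (it, state), _ = lax.scan(body, init, xs=None, length=itermax)
    return it, state


it, state = run(10)
print(int(it), int(state))  # prints "11 7"
```

Here 7 real steps happen and the last 3 are no-ops. Both `lax.scan` and `lax.cond` support reverse-mode differentiation; note that when the predicate is batched under `jax.vmap`, both branches are evaluated. The direct translation with `lax.while_loop(lambda c: c[0] < itermax, ...)` avoids the wasted no-op steps but does not support reverse-mode differentiation.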
However, when I try running `lax.scan`, I get this error:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 6) of type for function in_loop is non-hashable.
From looking around online, it seems that `lax.scan` makes my hashable static variables non-hashable? (I'm not sure about this.)
Is there a fix for what I am trying to do?
My code is not super readable, so here is a minimal example that replicates my error:
Thanks! :)