Non-hashable static arguments not allowed for pmap but allowed for vmap? #14159
somearthling asked this question in Q&A
-
This sounds like a bug: a traced value should never be valid as a static argument, so you should get an error regardless of whether you use `vmap` or `pmap`. For example:

```python
import os
# Must be set before JAX initializes its backends; simulates 8 CPU devices.
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

@jax.jit
def g_vmap(x, y):
    return jax.vmap(jax.jit(f, static_argnums=1))(x, y)

@jax.jit
def g_pmap(x, y):
    return jax.pmap(jax.jit(f, static_argnums=1))(x, y)

x = jnp.ones((8, 2))
y = jnp.ones((8, 2))

g_vmap(x, y)
# ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses.

g_pmap(x, y)
# ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses.
```
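A minimal sketch of the usual way around this, assuming the goal is just to get the example above running: keep a traced value as an ordinary (dynamic) argument, and reserve `static_argnums` for hashable Python values that are known before tracing starts. `g_scaled` and its `power` argument are hypothetical, added only to illustrate a legitimate static argument.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x, y):
    return x + y


@jax.jit
def g_vmap(x, y):
    # y is a tracer inside the outer jit, so it is passed as a normal
    # (dynamic) argument instead of being marked static.
    return jax.vmap(f)(x, y)


@partial(jax.jit, static_argnums=1)
def g_scaled(x, power):
    # Hypothetical example: `power` is a hashable Python int known at trace
    # time, so it can be static. A new value triggers a recompilation.
    return jax.vmap(lambda row: row ** power)(x)


x = jnp.ones((8, 2))
y = jnp.ones((8, 2))
print(g_vmap(x, y))        # every entry is 2.0
print(g_scaled(2 * x, 3))  # every entry is 8.0
```

Passing a `jnp` array (or any tracer) as `power` would hit the same "Non-hashable static arguments" error, which is the behavior the example above expects for both `vmap` and `pmap`.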
-
I have this piece of code:

```python
qnode = jit(qml.QNode(circuit, dev, interface='jax'), static_argnums=2)
probs = pmap(qnode, in_axes=(0, None, None))(x, p, filt)
```

inside a function that is already jitted, so `x` and `p` are JAX tracers at this point. With `vmap` in place of `pmap` this runs fine, but with `pmap` as shown I get the following error:

```
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function circuit is non-hashable.
```
Is this intended, or is there something I'm doing wrong that I can fix?
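For comparison, `pmap` has its own way of marking static arguments, `static_broadcasted_argnums`. The sketch below is only an illustration, not a fix for the QNode itself: `circuit` is a hypothetical stand-in for the PennyLane circuit, `n_layers` is a hypothetical static argument, and it assumes that value is a concrete, hashable Python object available outside any `jit`. A traced value such as `filt` inside the outer jitted function would still be rejected as a static argument.

```python
import jax
import jax.numpy as jnp


def circuit(x, p, n_layers):
    # Hypothetical stand-in for the QNode: the Python loop needs n_layers
    # to be a concrete int, which is why it is treated as static.
    out = x
    for _ in range(n_layers):
        out = out * p
    return jnp.sum(out)


# Map over however many local devices exist (1 on a plain CPU setup).
n = jax.local_device_count()
x = jnp.ones((n, 3))
p = 2.0

# Argument 2 is static and broadcast to every device; it must be hashable.
# in_axes is indexed by the original positional arguments.
probs = jax.pmap(
    circuit,
    in_axes=(0, None, None),
    static_broadcasted_argnums=(2,),
)(x, p, 3)
print(probs)  # one value per device, each 24.0
```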