Update global variable on pmap function #7697
lucasliunju asked this question in Q&A
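When a function changes a global value, that is known as a side effect. JAX transforms like `pmap` and `jit` are designed for pure functions, so assigning to a global inside the transformed function does not update it the way it would in ordinary Python. The operation you're doing looks more like a `lax.scan`, where the updated value is carried explicitly from one step to the next:

```python
from jax import lax

# train, g, initialize_v and batches stand for your own code.
def f(v, x):
    v = train(v, x)   # carry: the updated value of v
    return v, g(v)    # (new carry, per-step output)

v = initialize_v()
v, outs = lax.scan(f, v, batches)
```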
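To see the scan pattern end to end, here is a minimal runnable sketch. The bodies of `train` and `g`, the zero initial value and the toy `batches` array are hypothetical stand-ins chosen only so the code runs; the structure of `f` and the `lax.scan` call follow the snippet above.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for the user's training step: accumulate the batch into v.
def train(v, x):
    return v + x

# Hypothetical stand-in for a per-step quantity to record.
def g(v):
    return 2.0 * v

def f(v, x):
    v = train(v, x)   # new carry, passed to the next step
    return v, g(v)    # (carry, per-step output)

v0 = jnp.zeros(3)                         # hypothetical initialize_v()
batches = jnp.arange(12.0).reshape(4, 3)  # 4 toy batches of shape (3,)

v_final, outs = lax.scan(f, v0, batches)
print(v_final)     # [18. 22. 26.]  (sum of all batches)
print(outs.shape)  # (4, 3)  (one g(v) per step, stacked)
```

The key design point is that `v` never escapes the function as a side effect: `lax.scan` threads it through as the carry and returns its final value together with the stacked per-step outputs `outs`.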
Hi,
I would like to ask how to update a global value inside a pmap function. For example:
I would like to reuse the global value "v" and update it each time I run pmap(). However, I find that the value of "v" is not updated when I run it. Could you give me some advice on how to solve this?
Thank you very much!
Yong
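The same idea applies to `pmap` itself: instead of mutating a global, pass `v` in as an argument, return the updated value, and rebind the name outside the mapped function. The sketch below assumes a hypothetical `update` rule (add the sum of each device's batch to that device's copy of `v`) and toy all-ones batches; it runs on however many local devices are available, including a single CPU.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Hypothetical per-device update rule: v is this device's scalar copy,
# x is this device's batch of shape (4,).
def update(v, x):
    return v + x.sum()

p_update = jax.pmap(update)  # maps over the leading axis of both arguments

v = jnp.zeros(n_dev)               # one copy of v per device
batches = jnp.ones((3, n_dev, 4))  # 3 steps, each with a (4,)-batch per device

for x in batches:
    v = p_update(v, x)  # rebind v outside the pmapped function

print(v)  # every entry is 12.0 (3 steps x 4 ones)
```

If every device should end up with the same `v`, the update would typically combine results across devices (for example with `lax.pmean` over a named `axis_name`); that is left out here to keep the sketch minimal.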