Is "a = a + b" or "a += b" applied in-place inside jit()? #19260
This question is hard to answer, because it is based on the assumption that a sub-computation like `a += b` …

You can get some visibility into the choices the compiler makes for a particular program by using the ahead-of-time lowering and compilation APIs. For example, here's how the compiler handles a simple version of this pattern:

```python
import jax
import jax.numpy as jnp

def f(x):
    x -= 1
    return x.sum()

x = jnp.arange(10.0)
print(jax.jit(f).lower(x).compile().as_text())
```
Notice that in the compiled output, the entire operation … So the answer to your question is no: this operation is not applied in-place (that is, the input buffer is not modified), but neither are the values copied to a different buffer. Does that help answer your question?
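The same point can be checked from the caller's side. This is a minimal sketch, assuming a recent JAX version: because JAX arrays are immutable, `x -= 1` inside the traced function only rebinds the local name, so the caller's array is unchanged after the call. The last lines use `donate_argnums`, which is the explicit way to tell `jax.jit` it may reuse an input buffer for an output; whether donation actually takes effect depends on the backend (where it is not supported, JAX may only emit a warning), and a donated array must not be used afterwards.

```python
import jax
import jax.numpy as jnp

def f(x):
    x -= 1  # rebinds the local name x to a new traced value, x - 1
    return x.sum()

x = jnp.arange(10.0)
result = jax.jit(f)(x)

# Whatever buffer strategy XLA chose internally, the caller's array is intact.
print(x)       # still 0.0 .. 9.0
print(result)  # 35.0

# Opting in to buffer reuse: donate argument 0 so XLA *may* write the output
# into the input's buffer. Backend-dependent; never reuse a donated array.
g = jax.jit(lambda v: v - 1, donate_argnums=0)
y = g(jnp.arange(10.0))
print(y)
```

Donation is opt-in precisely because it changes what the caller may do with the input afterwards; without it, `jit` never overwrites a caller-visible buffer.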
Regarding https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at, I'd like to know: is `a = a + b` or `a += b` applied in-place inside `jit()`?
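The two forms in the question can be compared directly. A minimal sketch, assuming a recent JAX version: on an immutable JAX array, `a += b` behaves like `a = a + b` and produces a new array, so another reference to the old array is unaffected. The linked `.at` documentation notes that inside `jit`, an update such as `x.at[idx].set(y)` may be performed in-place by the compiler when the input is not reused; from Python the result is still a new array. The function name `update` is purely illustrative.

```python
import jax
import jax.numpy as jnp

a = jnp.zeros(3)
b = jnp.ones(3)
alias = a  # a second reference to the same immutable array

a += b  # equivalent to a = a + b: a is rebound to a new array
print(a)      # [1. 1. 1.]
print(alias)  # [0. 0. 0.]  (the original array was not modified)

@jax.jit
def update(x):
    # Functional indexed update. Inside the compiled computation XLA may
    # reuse x's buffer, since x is not used again after this point.
    return x.at[0].set(10.0)

updated = update(alias)
print(updated)  # first element set to 10.0, the rest 0.0
print(alias)    # still all zeros
```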