Hi,

I'm writing because I have some ideas about using donated arguments to speed up computations, but I'm not sure how donation works. I would appreciate any ideas and clarification.
The following is my understanding.
Let's say I have a JAX array `x`. If I want to mutate `x`, the JAX way is to apply a transformation and then overwrite `x` with the new values.
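A minimal sketch of this pattern, since the original snippet was not preserved. `func` is a hypothetical stand-in for the transformation, here written with the `.at[...].set(...)` functional update:

```python
import jax
import jax.numpy as jnp

def func(x):
    # Hypothetical transformation: returns a *new* array; the input is untouched.
    return x.at[0].set(42.0) * 2.0

x = jnp.arange(10, dtype=jnp.float32)
x = func(x)  # rebind the name x to the result instead of mutating in place
print(x[:3])
```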
This is where it gets tricky: what happens to the memory? Let's say we are on a CPU, so there is no need to consider where the data lives. My impression from the JAX documentation is that we create a copy of `x` with the updated values and then rebind `x` so that it points to this new memory. Asynchronous dispatch hides the latency of allocating the new memory by letting Python continue with subsequent computations. We can see that `x` becomes a new object by printing its `id()` before and after the function is called, and this happens even inside JITted functions.
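A sketch of that `id()` check, both eagerly and under `jax.jit`; `step` is a hypothetical function. Keep in mind that `id()` identifies the Python wrapper object, so a changed `id()` on its own says nothing about whether the underlying device buffer was reused:

```python
import jax
import jax.numpy as jnp

def step(x):
    # Hypothetical update: any elementwise op produces a new jax.Array object.
    return x + 1.0

step_jit = jax.jit(step)

x = jnp.zeros(10, dtype=jnp.float32)
id_before = id(x)
x = step(x)        # eager call
id_eager = id(x)
x = step_jit(x)    # JIT-compiled call
id_jit = id(x)

# Each result is created while the previous array is still alive,
# so consecutive ids necessarily differ.
print(id_before, id_eager, id_jit)
```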
I know the array type holds a pointer to the actual data buffer allocated on the device. One question is whether this buffer is reused: if it is, then even though we get a new object, that object is just a small wrapper around the same buffer. Maybe JIT optimizes this?
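To look at the buffer rather than the wrapper, `jax.Array.unsafe_buffer_pointer()` returns the address of the underlying buffer for a single-device array. A minimal sketch, assuming no donation and a hypothetical `step`: because the input is still alive when the output is produced, the output should live in a different buffer.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Hypothetical update. Without donation the output cannot alias
    # x's buffer, so XLA writes it into a newly allocated one.
    return x * 2.0 + 1.0

x = jnp.ones(10, dtype=jnp.float32)
ptr_in = x.unsafe_buffer_pointer()
y = step(x)  # x is still referenced, so y needs its own buffer
ptr_out = y.unsafe_buffer_pointer()
print(hex(ptr_in), hex(ptr_out))
```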
Now, let's get into donated arguments. The idea is that once we donate a buffer, we can't use it again. I like to think of this as similar to how Rust consumes (moves) arguments that are passed into a function by value.
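A sketch of this Rust-like behaviour, assuming a hypothetical two-argument `func(x, y)` where only `y` is donated. On backends that honour donation, the donated array is marked deleted after the call, and reading it raises an error, much like using a value after it has been moved in Rust.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnames=("y",))
def func(x, y):
    # Hypothetical body: one float32[10] output, same shape/dtype as y.
    return x + y

x = jnp.ones(10, dtype=jnp.float32)
y = jnp.full(10, 2.0, dtype=jnp.float32)
out = func(x, y)

print(float(out[0]))
print(y.is_deleted())  # expected True where the backend honours donation
# Reading y now (e.g. print(y)) would raise an error if it was deleted.
```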
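However, what if we donate `x` as well? That's where things get weird, and this is why I would like some explanations. Donating only `x` works fine, but when I donate both `x` and `y`, I get this warning:

UserWarning: Some donated buffers were not usable: float32[10].
See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"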
Why does the first case work, but this one does not? Also, if `func` only takes `x` and I donate `x`, `x` does not show as deleted, but I don't get the warning either. This suggests that every function used in the pattern `x = func(x)` should have its argument donated. But at this point, I have no idea whether donating the argument in this case does anything, and it's strange that some cases work and some don't.
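On the `x = func(x)` case: the name `x` is rebound to the returned array, so `x.is_deleted()` afterwards inspects the new output, not the donated input. A minimal sketch that keeps a second, hypothetical reference `x_old` to the original array so its state can be checked. With one input and one output of the same shape and dtype, the donated buffer has a matching output to be reused for, which is consistent with no warning being raised.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnames=("x",))
def func(x):
    # Hypothetical update with one float32[10] output matching the input.
    return jnp.sin(x) + 1.0

x = jnp.zeros(10, dtype=jnp.float32)
x_old = x    # second reference to the original array object
x = func(x)  # the name x now refers to the new output array

print(x.is_deleted())      # False: this is the freshly returned array
print(x_old.is_deleted())  # expected True where the backend honours donation
```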
Now, I wanted to see what was going on with the data buffer pointer, and found something very weird:
@partial(jax.jit, donate_argnames=("x", "y"))
238847680
/home/wind/PythonEnvs/python3/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1316: UserWarning: Some donated buffers were not usable: float32[10].
See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
238847680
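A reconstruction of this kind of pointer experiment, under the assumption (not stated in the post) that `func(x, y)` returns a single float32[10] array; the body of `func` is hypothetical. The buffer-donation FAQ describes donated buffers being reused for outputs of matching shape and dtype, and a single float32[10] output can take only one of the two donated float32[10] buffers. That would account for both the unchanged pointer and the warning about the other buffer.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnames=("x", "y"))
def func(x, y):
    # Hypothetical body: one float32[10] output for two float32[10] donations.
    return x * y

x = jnp.arange(10, dtype=jnp.float32)
y = jnp.full(10, 3.0, dtype=jnp.float32)
ptr_x = x.unsafe_buffer_pointer()
ptr_y = y.unsafe_buffer_pointer()

x = func(x, y)  # may warn: "Some donated buffers were not usable: float32[10]"
ptr_out = x.unsafe_buffer_pointer()

# Where donation is honoured, the output should reuse one donated buffer.
print(ptr_out in (ptr_x, ptr_y))
```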
Note that donating `x` allows its buffer to be reused, even with the warning. However, it's weird that we get a warning when donating both arguments but not when donating only `x`.
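Two hypothetical variants, `update_x` and `update_both`, sketch how the donation count could be matched to the outputs, assuming the shape/dtype matching rule above applies to the author's setup. Donating only the argument that gets rebound pairs one donated buffer with one output; returning one float32[10] output per donated float32[10] input gives every donated buffer a slot, which should avoid the warning.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Variant A (hypothetical): donate only the argument that gets rebound.
@partial(jax.jit, donate_argnames=("x",))
def update_x(x, y):
    return x * y  # one output <- one donated buffer

# Variant B (hypothetical): two outputs, so both donated buffers can be reused.
@partial(jax.jit, donate_argnames=("x", "y"))
def update_both(x, y):
    return x * y, x + y  # two float32[10] outputs

x = jnp.arange(10, dtype=jnp.float32)
y = jnp.full(10, 3.0, dtype=jnp.float32)
x = update_x(x, y)        # y stays usable; only x was donated
x, y = update_both(x, y)  # rebind both names to the new arrays
```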
I would appreciate it if anyone has insights into how this works and could clarify best practices and how this affects performance.