I recently went through the section on runtime value debugging in JAX in the documentation. However, I didn't get the expected output for the very first example. Instead, the two output lines appear in the opposite order. I understand that distinct calls to … The expected ordering can be restored by either using the … Am I misunderstanding something about the printing dynamics here?

Example.
Expected output.
Actual output.
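The original snippet and outputs are not preserved here. As a rough, hypothetical sketch of the kind of program in question, the function below prints its input, computes `jnp.sin` of it, and prints the result inside a `jax.jit`-compiled function. The function name `f` and the format strings are illustrative assumptions, not the documentation's exact code. With the default (unordered) `jax.debug.print`, the two lines are not guaranteed to appear in program order.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Print the input first...
    jax.debug.print("x: {}", x)
    # ...then a value that depends on it.
    y = jnp.sin(x)
    jax.debug.print("y: {}", y)
    return y

# The two printed lines may appear in either order, because these
# prints are unordered effects.
result = f(2.0)
```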
This is unusual... What platform are you running on?
Unfortunately I wasn't able to reproduce your behavior, but it seems plausible to me. We assumed data dependence was sufficient to guarantee the ordering of prints in an XLA program, but that assumption is probably wrong!

XLA is a functional compiler and is free to do its own optimizations, such as rematerializing intermediate values instead of storing them. In your program, therefore, XLA could choose to first execute `print(sin(x))` and then execute `print(x)`, instead of first storing `y = sin(x)`.

These compiler optimizations should not delete the call to `print`, but more aggressive reordering is possible due to rematerialization. We should update our docs to reflect this potential behavior.
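If you need the prints to appear in program order regardless of how XLA schedules or rematerializes the computation, `jax.debug.print` accepts an `ordered=True` keyword that makes the call an ordered effect, so it is sequenced relative to other ordered prints. The following is a minimal sketch, with a hypothetical function `f` and illustrative format strings; it also calls `jax.effects_barrier()` so the script waits for the asynchronous prints to finish before moving on.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # ordered=True makes each print an ordered effect, so the two prints
    # run in program order instead of relying on data dependence alone.
    jax.debug.print("x: {}", x, ordered=True)
    y = jnp.sin(x)
    jax.debug.print("y: {}", y, ordered=True)
    return y

result = f(2.0)
# Block until pending side effects (including the prints) have completed.
jax.effects_barrier()
```

The trade-off is that ordered effects add an explicit sequencing constraint to the compiled program, which may limit some of the scheduling freedom that makes unordered prints cheaper.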