I'm attempting to implement computations that use a `while_loop` to express tail-recursive computations, but I'm also trying to understand whether it is possible, inside code that I want to JIT, to include a reallocation strategy for when the `while_loop` runs longer than the leading (zeroth) dimension of a pre-allocated array.
Specifically, I'm considering `body_fun` instances where a passed-in state `x` is accessed linearly (only once), via `x.at[idx].set(y)`, before being returned. From the documentation, it appears that this form of linear access is optimized into an in-place update.
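For concreteness, here is a minimal sketch of the pattern I mean. The function name `fill_squares` and the "write `idx**2`" payload are hypothetical stand-ins for the real computation; whether the copy of `buf` is actually elided is decided by XLA's buffer assignment, not guaranteed by this code.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def fill_squares(x):
    """Write idx**2 into x[idx] for every slot of the pre-allocated buffer x."""

    def cond_fun(state):
        idx, _ = state
        return idx < x.shape[0]  # x.shape is static under jit

    def body_fun(state):
        idx, buf = state
        y = (idx * idx).astype(buf.dtype)
        # buf is consumed exactly once and the result is returned, which is
        # the linear-access pattern XLA can turn into an in-place update.
        return idx + 1, buf.at[idx].set(y)

    _, out = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), x))
    return out


out = fill_squares(jnp.zeros(5, dtype=jnp.float32))  # values: 0, 1, 4, 9, 16
```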
That's great, but because XLA's `While` requires the loop state to keep a fixed shape, it appears that I can't resize the first dimension of the array from inside the loop if, say, `idx` goes out of bounds (i.e., the computation runs longer than I expected when I pre-allocated `x`).
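A small illustration of why this bites: `lax.while_loop` requires `body_fun` to return a carry with the same shapes and dtypes it received, so the buffer cannot grow, and, as far as I understand JAX's default semantics for `.at[...].set`, writes at out-of-bounds indices are silently dropped rather than raising. The function `write_past_end` below is a hypothetical example, not part of any API.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_past_end(x, n_steps):
    """Run n_steps iterations, writing 1.0 at x[idx] each time."""

    def cond_fun(state):
        idx, _ = state
        return idx < n_steps

    def body_fun(state):
        idx, buf = state
        # Once idx >= buf.shape[0], the scatter update is dropped by default.
        return idx + 1, buf.at[idx].set(1.0)

    _, out = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), x))
    return out


out = write_past_end(jnp.zeros(3), 5)  # [1., 1., 1.]: the writes at 3 and 4 are lost
```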
Assuming this is not possible at present (which seems likely), would this sort of computation be possible with dynamic shapes?
I am also quite curious whether this is possible now within code that I wish to JIT, even with sub-optimal performance.
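One workaround I can sketch today moves the reallocation outside the compiled loop: a jitted chunk runs until it either finishes or fills the buffer, and a host-side Python driver doubles the buffer and resumes. This is an assumption-laden sketch, not a JAX feature: the Collatz sequence is just a hypothetical stand-in for a tail-recursive computation of unknown length, and `collatz_chunk`/`collatz` are illustrative names. The costs are a host round-trip per chunk and a recompilation for every new buffer capacity, since the shape is static under `jit`.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def collatz_chunk(buf, count, n):
    """Append Collatz terms to buf until n == 1 or buf is full."""
    capacity = buf.shape[0]  # static, so each new capacity triggers a recompile

    def cond_fun(state):
        _, count, n = state
        return (n != 1) & (count < capacity)

    def body_fun(state):
        buf, count, n = state
        n = jnp.where(n % 2 == 0, n // 2, 3 * n + 1)
        return buf.at[count].set(n), count + 1, n

    return lax.while_loop(cond_fun, body_fun, (buf, count, n))


def collatz(start, initial_capacity=4):
    """Host-side driver: grow the buffer and resume whenever it fills up."""
    buf = jnp.zeros(initial_capacity, dtype=jnp.int32)
    count, n = jnp.int32(0), jnp.int32(start)
    while True:
        buf, count, n = collatz_chunk(buf, count, n)
        if int(n) == 1:
            return buf[: int(count)]
        # Buffer is full but the computation is not done: double and resume.
        buf = jnp.concatenate([buf, jnp.zeros_like(buf)])


terms = collatz(6)  # [3, 10, 5, 16, 8, 4, 2, 1], after one regrow from 4 to 8
```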
Update: I've been reading through the MLIR Discourse to determine the roadmap and blockers for fixed-rank, dynamic-dimension support (which is ultimately what I would need for this functionality to be supported by operations like `While`). It seems that both IREE and a speculative DISC compiler support some form of this code path.
I noticed that JAX exposes MHLO, which seems to offer dynamic shape support. Does MHLO already support what I'm seeking, so that the main technical obstacle is adapting JAX's tracer to support dynamic shapes?