Hi there,

I have an array `[0 0 0 1 2 3]` that I need to store in an output of the form `[0 0 0 0 0 0]` so that the result is `[1 2 3 0 0 0]`. I could slice out the `[1 2 3]` part and use `pl.store` to place it at the beginning of the output array, but if I use slices whose size is not a power of 2, my kernel crashes (which I was told is normal).

My question is: can I use pointers to tell the kernel to start writing the array a few elements before the output buffer starts, so that I get the result I am after? I am very confident that I can do this in Triton because it uses pointers, but I was wondering whether I can do it in Pallas.

If not, I have to use `jax.numpy.roll`, which is very costly.

Many thanks!
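To make the desired result concrete, here is a minimal plain-JAX sketch (not a Pallas kernel) of the shift described above. The helper name `shift_left_zero_fill` is hypothetical, and the full-width offset-index-plus-mask pattern is only an assumption about how one might avoid a size-3 slice; whether a Pallas kernel accepts the equivalent masked access is exactly what the question asks.

```python
import jax.numpy as jnp


def shift_left_zero_fill(x, shift):
    """Move x[shift:] to the front of a same-sized output, zero-filling the tail.

    Hypothetical helper: it always works on the full width of x, so no
    odd-sized slice is ever taken.
    """
    n = x.shape[0]
    idx = jnp.arange(n) + shift           # source index for each output slot
    mask = idx < n                        # slots whose source lies inside x
    safe_idx = jnp.where(mask, idx, 0)    # clamp out-of-range indices to 0
    return jnp.where(mask, x[safe_idx], 0)


x = jnp.array([0, 0, 0, 1, 2, 3])

out = shift_left_zero_fill(x, 3)
rolled = jnp.roll(x, -3)  # the jax.numpy.roll baseline mentioned in the post

print(out)     # [1 2 3 0 0 0]
print(rolled)  # [1 2 3 0 0 0]
```

For this particular input the two results coincide because the first three elements are already zero; in general `jnp.roll` wraps elements around, while the masked version fills with zeros.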