What is the best way to parallelize complex nested loops #11826
nkupelioglu asked this question in Q&A
Usually that kind of access pattern (…)
My usual approach for those ugly loops is to reshape the data (…)
You cannot modify a JAX array in place, but there are two workarounds: …
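The reply's concrete details were lost in extraction, so the following is only an illustration of the two ideas it names (reshaping the data, and living with immutable arrays), not a reconstruction of its workarounds. It assumes a hypothetical stride-2 loop; the function names and shapes are made up. pairs_reshape removes the loop by reshaping so that the skip becomes its own axis, and write_rows shows JAX's functional-update style with lax.dynamic_update_slice, using jax.jit's donate_argnums so XLA may reuse the input buffer.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def pairs_loop(x):
    # Hypothetical "skip" loop: visit rows 0, 2, 4, ... and combine each row
    # with its neighbour, writing one output row per iteration.
    n = x.shape[0]
    out = jnp.zeros((n // 2,) + x.shape[1:], dtype=x.dtype)
    for i in range(0, n, 2):
        # Functional update: .at[].set returns a new array instead of mutating.
        out = out.at[i // 2].set(x[i] * x[i + 1])
    return out


def pairs_reshape(x):
    # Same computation without a loop: reshape so the stride-2 pattern becomes
    # its own axis, then operate on whole slices at once (n must be even).
    n = x.shape[0]
    blocks = x.reshape((n // 2, 2) + x.shape[1:])
    return blocks[:, 0] * blocks[:, 1]


@partial(jax.jit, donate_argnums=0)
def write_rows(buf, rows, start):
    # Write `rows` into the 2-D `buf` starting at row `start`. JAX arrays are
    # immutable, so this returns a new array; donating `buf` lets XLA reuse its
    # memory for the result on backends that support donation.
    return lax.dynamic_update_slice(buf, rows, (start, 0))


# Usage: both versions agree; a donated buffer must not be used afterwards.
x = jnp.arange(24.0).reshape(6, 4)
same = jnp.allclose(pairs_loop(x), pairs_reshape(x))
new = write_rows(jnp.zeros((4, 3)), jnp.ones((2, 3)), 1)
```

Per the JAX documentation, inside jit-compiled code an x.at[idx].set(y) whose input x is not used afterwards can be compiled into an in-place update, so immutability is often cheaper than it first appears.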
Hello everyone,
I have a nested loop that is indexed in a very complex way, with slices and skips.
where matrix is of rank 5 and x is of rank 4; both are also fairly large. I point this out because I would like to update them in place (e.g. with lax.dynamic_slice?). Every update in this nested loop is independent of every other iteration of each loop, so technically it should be possible to unroll everything here. I am fairly new to JAX, so I wanted to ask how you would implement this (e.g. with nested vmaps or fori_loop) for the lowest memory consumption and the best runtime.
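The original loop did not survive extraction, so this sketch uses a hypothetical stand-in with the same flavor: a doubly nested loop with independent iterations, a skip over the first axis and strided slices on the last. The shapes (rank 3 rather than 5 and 4), the loop body and the function names are assumptions for illustration only. It writes the same computation as nested lax.fori_loop, as nested jax.vmap, and as a loop-free strided-slice expression.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in: matrix and x both have shape (N, M, K), N and K even.
# Output cell (i, j) uses row 2*i (a skip over the first axis) and strided
# slices of the last axis:
#   out[i, j] = dot(matrix[2*i, j, ::2], x[2*i, j, 1::2])


def loop_reference(matrix, x):
    # Plain Python loops: handy for checking results, but every iteration is a
    # separate op (or gets unrolled if this function is jitted).
    n, m, _ = matrix.shape
    out = jnp.zeros((n // 2, m), dtype=matrix.dtype)
    for i in range(n // 2):
        for j in range(m):
            out = out.at[i, j].set(jnp.dot(matrix[2 * i, j, ::2], x[2 * i, j, 1::2]))
    return out


def loop_fori(matrix, x):
    # Nested lax.fori_loop: one output buffer is carried through both loops
    # and the iterations run sequentially.
    n, m, _ = matrix.shape

    def outer(i, out):
        def inner(j, out):
            val = jnp.dot(matrix[2 * i, j, ::2], x[2 * i, j, 1::2])
            return out.at[i, j].set(val)
        return lax.fori_loop(0, m, inner, out)

    out0 = jnp.zeros((n // 2, m), dtype=matrix.dtype)
    return lax.fori_loop(0, n // 2, outer, out0)


def loop_vmap(matrix, x):
    # Nested jax.vmap: write the body for a single (i, j) and let vmap batch
    # it over all rows and columns at once.
    n, m, _ = matrix.shape

    def cell(i, j):
        return jnp.dot(matrix[2 * i, j, ::2], x[2 * i, j, 1::2])

    over_cols = jax.vmap(cell, in_axes=(None, 0))       # fixed i, batched j
    over_rows = jax.vmap(over_cols, in_axes=(0, None))  # batched i, all j
    return over_rows(jnp.arange(n // 2), jnp.arange(m))


def loop_sliced(matrix, x):
    # No loop at all: the skips become strided slices and the dot product
    # becomes an elementwise product summed over the last axis.
    return jnp.sum(matrix[::2, :, ::2] * x[::2, :, 1::2], axis=-1)


# Usage with small made-up shapes.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
matrix = jax.random.normal(k1, (4, 3, 6))
x = jax.random.normal(k2, (4, 3, 6))
out = jax.jit(loop_fori)(matrix, x)  # shape (2, 3)
```

As a rough guide rather than a rule: fori_loop runs the iterations one after another but keeps only the carried output buffer, which XLA can usually update in place under jit; vmap runs all cells together and may materialize the gathered slices for every cell at once; and when the body can be written as slicing plus elementwise ops (or jnp.einsum), the loop-free form is usually the easiest for XLA to optimize. Profiling with realistic shapes is the reliable way to choose.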