Parallel nested graph generation attempt #28341
Unanswered
DiagRisker asked this question in Q&A
Replies: 0 comments
Hello,
I'm trying to speed up this simple function (with parallel scalar injection):
This produces a rank-2 tensor (although only j matters in the end; i is repeated).
So I tried two versions focusing on j, one with jax.vmap and another with jax.lax.scan:
Ideally, the array would resemble this with jax.vmap:
However, as with NumPy arrays, the jax.Array type does not support non-uniform shapes, so I'm fine with flattening or concatenating the result.
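To make the non-uniform-shape constraint concrete, here is a minimal sketch of one common workaround: pad every row to a fixed length inside jax.vmap, carry a validity mask alongside it, and flatten afterwards. The original function is not shown here, so `N`, `row`, and the assumption that row j holds j + 1 valid entries are all hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

N = 4  # hypothetical number of outer steps j


def row(j):
    # Hypothetical inner computation: row j has j + 1 valid entries.
    # Every row is computed at the full length N so shapes stay uniform.
    i = jnp.arange(N)
    values = (j * 10 + i).astype(jnp.float32)
    mask = i <= j
    # Zero out the padding slots and return the mask next to the data.
    return jnp.where(mask, values, 0.0), mask


# Both outputs have the uniform shape (N, N).
padded, mask = jax.vmap(row)(jnp.arange(N))

# Boolean indexing produces a data-dependent shape, so it is done here
# on concrete arrays, outside of any jit-compiled function.
flat = padded[mask]
```

The padded entries cost some wasted compute, but the array stays rectangular inside the transformed function, and the ragged flattening happens only once at the end.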
But to my surprise, jax.lax.scan doesn't seem to work with this either.
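A likely reason is that jax.lax.scan stacks the per-step outputs along a new leading axis, so every step has to return arrays of the same shape and dtype. The sketch below shows a scan that satisfies this by padding each step's output and returning a mask. The names `N` and `step`, and the rule that step j has j + 1 valid entries, are assumptions made for illustration and not the original code.

```python
import jax
import jax.numpy as jnp

N = 4  # hypothetical number of scan steps j


def step(carry, j):
    # Hypothetical per-step computation with j + 1 valid entries,
    # padded to the fixed length N so each step's output has shape (N,).
    i = jnp.arange(N)
    mask = i <= j
    values = jnp.where(mask, (j * 10 + i).astype(jnp.float32), 0.0)
    # The carry must keep the same shape and dtype across steps;
    # here it is a running count of valid entries.
    new_carry = carry + jnp.sum(mask, dtype=jnp.int32)
    return new_carry, (values, mask)


total, (padded, mask) = jax.lax.scan(step, jnp.int32(0), jnp.arange(N))

# Drop the padding outside of scan/jit, where shapes may depend on data.
flat = padded[mask]
```

If the steps do not depend on each other through the carry, jax.vmap over j is usually the better fit, since scan runs its steps sequentially.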
Does anyone on the forum know how to make it work?
Thanks in advance!