Thanks for the question! The difference here is that if you use |
Dear developers,
I have a function as follows
where `t1.shape=(8,38)` and `t2.shape=(8,8,38,38)`. The JIT-compiled function is quite simple and uses a buffer of about 1 MB, as expected.
However, if I replace `numpy.tril_indices` with `jax.numpy.tril_indices`, the compiled function becomes much more complicated and the buffer size doubles (see module_0309.jit_amplitudes_to_vector.cpu_after_optimizations-buffer-assignment.txt). What causes this, and should I just use numpy instead of jax.numpy for indexing purposes? Thank you in advance.
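For illustration, here is a minimal sketch of the two variants. The post does not include the function body, so `amplitudes_to_vector` below is a hypothetical reconstruction that packs `t1` together with the lower triangle (over the first two axes) of `t2`, and `make_amplitudes_to_vector` is a hypothetical helper. With `np.tril_indices`, the indices are computed in Python while tracing and end up in the compiled program as constants. With `jnp.tril_indices`, they are JAX arrays, and whether their computation is staged into the compiled program may depend on the JAX version. Both variants should produce the same numbers; comparing the optimized HLO text of each is one way to see how the compiled programs differ.

```python
import numpy as np
import jax
import jax.numpy as jnp


# Hypothetical reconstruction: the original function body is not shown.
# Packs t1 plus the lower triangle (i >= j) of t2 over its first two axes.
def make_amplitudes_to_vector(tril_indices):
    @jax.jit
    def amplitudes_to_vector(t1, t2):
        nocc = t1.shape[0]           # shapes are static: a plain Python int while tracing
        i, j = tril_indices(nocc)    # index arrays for the lower triangle, 36 pairs for nocc=8
        t2_tril = t2[i, j]           # advanced indexing -> shape (npair, nvir, nvir)
        return jnp.concatenate([t1.ravel(), t2_tril.ravel()])
    return amplitudes_to_vector


# np.tril_indices runs in Python at trace time; its result is a NumPy
# constant baked into the traced program.
pack_with_np = make_amplitudes_to_vector(np.tril_indices)
# jnp.tril_indices returns JAX arrays; depending on the JAX version, the
# index computation may be staged into the compiled program.
pack_with_jnp = make_amplitudes_to_vector(jnp.tril_indices)

rng = np.random.default_rng(0)
t1 = jnp.asarray(rng.standard_normal((8, 38)), dtype=jnp.float32)
t2 = jnp.asarray(rng.standard_normal((8, 8, 38, 38)), dtype=jnp.float32)

v_np = pack_with_np(t1, t2)
v_jnp = pack_with_jnp(t1, t2)
print(v_np.shape)  # (52288,) = 8*38 + 36*38*38

# Optimized HLO of each variant, for comparing what XLA actually compiled.
hlo_np = pack_with_np.lower(t1, t2).compile().as_text()
hlo_jnp = pack_with_jnp.lower(t1, t2).compile().as_text()
print(len(hlo_np), len(hlo_jnp))
```

The pattern of deriving index arrays with NumPy from static shapes is a common way to keep them as trace-time constants; the sketch only checks that both variants agree numerically, not which one compiles to the smaller buffer.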