First off, for running microbenchmarks in JAX, be sure to follow the recommendations at FAQ: Benchmarking JAX Code. I found the following timings for your calls on a Colab CPU:
_ = update_priorities2(tree_dmy[0],stack_p(indx)[0],priority_dmy[0]) # compile
%timeit jax.block_until_ready(update_priorities2(tree_dmy[0],stack_p(indx)[0],priority_dmy[0]))
# 1.49 ms ± 244 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
_ = j_update_tree(trees_Ns,stack_p(indx),priority_dmy) # compile
%timeit jax.block_until_ready(j_update_tree(trees_Ns,stack_p(indx),priority_dmy))
# 312 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
_ = v_update_tree(trees_Ns,stack_p(indx),priority_dmy) # compile
%timeit jax.block_until_ready(v_update_tree(trees_Ns,stack_p(indx),priority_dmy))
# 2.96 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
So it looks like the …
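For readers outside a notebook, the same measurement pattern (one warm-up call to compile, then timing with jax.block_until_ready) can be sketched with the standard-library timeit module. This is only an illustration: scaled_sum and its input are hypothetical stand-ins, not the functions from this thread.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Hypothetical stand-in for the functions timed in this thread.
    return jnp.sum(x * 2.0)


x = jnp.arange(1000.0)

# Warm-up call: triggers tracing and compilation so they are not timed.
jax.block_until_ready(scaled_sum(x))

# JAX dispatches work asynchronously, so block until the result is ready;
# otherwise only the dispatch overhead would be measured.
n_calls = 100
total = timeit.timeit(lambda: jax.block_until_ready(scaled_sum(x)), number=n_calls)
per_call_us = total / n_calls * 1e6
print(f"{per_call_us:.1f} µs per call")
```

The absolute numbers depend on the hardware and backend, so they are only comparable when taken on the same machine.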
-
I have the following code that operates on a tree-like data structure, trees_Ns in my code:

My questions are:

update_priorities2 is slower than that using the vmap, i.e. j_update_tree?

update_priorities2 function that contains a for loop? This for loop is needed because the updates to the tree's values must be performed in order, following the indices. Also, this function's performance scales strongly with the length of both indices and priorities. Are there any alternatives for constructing the function so that it does not scale with the length of the input, like using vmap?
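One commonly used option for an ordered update like this is jax.lax.scan. A Python for loop inside a jitted function is unrolled at trace time, so the traced program grows with the length of the inputs, whereas scan keeps a single loop body. The sketch below is not the code from this thread, which is not shown here: it assumes a hypothetical sum tree stored as a flat array with the root at index 1 and a power-of-two number of leaves, and the names update_one and update_priorities_scan are made up for illustration. The run time still grows linearly with the number of updates, because each step depends on the previous one.

```python
import jax
import jax.numpy as jnp
from jax import lax


def update_one(tree, leaf_idx, priority):
    # Assumed layout (hypothetical): flat array of size 2 * n_leaves,
    # root at index 1, leaves at n_leaves .. 2 * n_leaves - 1, slot 0 unused.
    n_leaves = tree.shape[0] // 2        # static Python int
    depth = n_leaves.bit_length() - 1    # static: n_leaves is a power of two
    node = leaf_idx + n_leaves
    delta = priority - tree[node]

    def body(_, carry):
        t, n = carry
        # Add the change to this node, then move to its parent.
        return t.at[n].add(delta), n // 2

    # depth + 1 iterations cover the leaf and every ancestor up to the root.
    tree, _ = lax.fori_loop(0, depth + 1, body, (tree, node))
    return tree


@jax.jit
def update_priorities_scan(tree, indices, priorities):
    def step(t, idx_and_priority):
        idx, priority = idx_and_priority
        return update_one(t, idx, priority), None

    # scan applies the updates strictly in order, one (index, priority) pair at a time.
    tree, _ = lax.scan(step, tree, (indices, priorities))
    return tree


tree = jnp.zeros(8)                      # 4 leaves
indices = jnp.array([0, 2, 0])           # leaf 0 is written twice; the later write wins
priorities = jnp.array([1.0, 3.0, 5.0])
new_tree = update_priorities_scan(tree, indices, priorities)
print(new_tree)
```

Because the repeated index is handled in sequence, the second write to leaf 0 replaces the first, which is the ordering guarantee the for loop was providing.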