
Different types of operations have different costs. In JAX, elementwise operations that combine values already sitting in existing arrays are relatively cheap (true on CPU, but especially so on GPU and TPU). Extracting elements of one array into another via indexing (a gather) is relatively expensive (again true on CPU, but especially so on GPU and TPU).
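One way to see the distinction is to look at the primitives JAX traces each kind of operation into. This is a minimal sketch with made-up arrays and hypothetical function names (`elementwise`, `gathered`). The exact jaxpr text varies between JAX versions, but the elementwise product shows up as a `mul` primitive while integer-array indexing shows up as a `gather`.

```python
import jax
import jax.numpy as jnp

def elementwise(x, y):
    # Combines values that are already in place: traces to a `mul` primitive.
    return x * y

def gathered(x, idx):
    # Pulls selected elements out into a new array: traces to a `gather` primitive.
    return x[idx]

x = jnp.arange(8.0)
y = jnp.ones(8)
idx = jnp.array([0, 2, 4, 6])

print(jax.make_jaxpr(elementwise)(x, y))
print(jax.make_jaxpr(gathered)(x, idx))
```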

What you're attempting here is to replace N / 2 duplicative but very efficient operations with N / 2 very inefficient ones. The result is going to be slower execution.
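The original function isn't shown here, so as a purely hypothetical stand-in, assume a symmetric pairwise quantity `f(x_i, x_j) = |x_i - x_j|`. The sketch below contrasts computing every entry with broadcasting (redundant but elementwise) against computing only the upper triangle, which needs a gather to pick out the pairs and a scatter to write them back. The names `pairwise_full` and `pairwise_half` are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical symmetric pairwise function: f(i, j) == f(j, i).

@jax.jit
def pairwise_full(x):
    # All N*N entries via broadcasting: pure elementwise subtract/abs,
    # even though roughly half of the entries are redundant.
    return jnp.abs(x[:, None] - x[None, :])

@jax.jit
def pairwise_half(x):
    # Only the upper triangle, then mirrored. The shape is static under jit,
    # so NumPy can build the index arrays at trace time.
    n = x.shape[0]
    i, j = np.triu_indices(n, k=1)
    vals = jnp.abs(x[i] - x[j])                             # gather
    upper = jnp.zeros((n, n), x.dtype).at[i, j].set(vals)  # scatter
    return upper + upper.T
```

Both functions return the same matrix. Whether the saved arithmetic outweighs the extra gather/scatter depends on your shapes and hardware, so timing both on your own device (calling `.block_until_ready()` on the result so asynchronous dispatch doesn't skew the measurement) is the reliable way to check.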

My suggestion would be to use your original function. Yes, it does twice as much work as strictly necessary, but the work it does is much more efficie…

Replies: 1 comment, 2 replies (@zahari-kassabov-opificio, @jakevdp)

Answer selected by zahari-kassabov-opificio
Category: Q&A
2 participants