How do I iterate over non-repeating combinations efficiently? #18919
Given an array of arbitrary values

    import jax.numpy as jnp
    import jax
    samples = jnp.linspace(0, 2, 1000)

I'd like to compute the expected value of some symmetric function, for example

    def f(a, b):
        return jnp.exp(-(a - b)**2)

over all pairs of samples. The simplest way I can think of is:

    @jax.jit
    def fgrid1(samples):
        x1, x2 = jnp.meshgrid(samples, samples)
        return jax.vmap(f)(x1, x2).mean()

and this works as expected (for one-dimensional inputs). However, I'd like to make this faster by iterating over non-repeating pairs of samples only, rather than all of them. To simplify the examples, please ignore the fact that this is subtly incorrect, as one would have to weight the case i==j by half. In pure Python one would do

    values = []
    for i in range(len(samples)):
        for j in range(i, len(samples)):
            values.append(f(samples[i], samples[j]))
    result = jnp.mean(jnp.array(values))

A way of doing something similar in JAX is:

    @jax.jit
    def fgrid2(samples):
        x1ind, x2ind = jnp.tril_indices(len(samples))
        return jax.vmap(f)(samples[x1ind], samples[x2ind]).mean()

However, this has rather worse runtime on my M2 Mac using CPU despite doing about half of the work:
I tried writing the loop explicitly and that still didn't beat

    def pairs(arr):
        # Number of pairs (i, j) with i <= j.
        l = len(arr)*(len(arr) + 1)//2
        a = jnp.empty((l, *arr.shape[1:]))
        b = jnp.empty((l, *arr.shape[1:]))
        def body(i, val):
            ((ai, bi), (a, b)) = val
            a = a.at[i].set(arr[ai])
            b = b.at[i].set(arr[bi])
            # Advance bi along the row; at the end of the row, start the next one at the diagonal.
            return jax.lax.cond(bi < len(arr)-1,
                lambda: ((ai, bi+1), (a, b)),
                lambda: ((ai+1, ai+1), (a, b))
            )
        (_idx, (aret, bret)) = jax.lax.fori_loop(0, l, body, ((0, 0), (a, b)))
        return aret, bret

    @jax.jit
    def fgrid3(arr):
        x1, x2 = pairs(arr)
        return jax.vmap(f)(x1, x2).mean()

Is there a way to improve on
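The half-weighting caveat mentioned in the question can be made concrete with a short sketch. Because `f` is symmetric, each strictly off-diagonal pair stands in for two ordered pairs, while each diagonal pair counts once. The function names `full_grid_mean` and `weighted_triangle_mean` are illustrative only and do not appear in the original post; the sketch assumes one-dimensional `samples`.

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.exp(-(a - b) ** 2)

@jax.jit
def full_grid_mean(samples):
    # Mean over all N*N ordered pairs, via broadcasting.
    return f(samples[:, None], samples[None, :]).mean()

@jax.jit
def weighted_triangle_mean(samples):
    n = samples.shape[0]  # static under jit, so tril_indices is safe here
    # Strictly lower triangle (i > j): each pair represents two ordered pairs.
    i, j = jnp.tril_indices(n, k=-1)
    off_diag = f(samples[i], samples[j]).sum()
    # Diagonal (i == j): each pair is counted once.
    diag = f(samples, samples).sum()
    return (2.0 * off_diag + diag) / (n * n)

samples = jnp.linspace(0, 2, 100)
print(full_grid_mean(samples), weighted_triangle_mean(samples))
```

The two values should agree up to floating-point rounding, which is what makes the triangular version a valid (if not necessarily faster) replacement for the full grid.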
Replies: 1 comment 2 replies
Different types of operations have different costs. In JAX, computing elementwise operations between values in existing arrays will be relatively efficient (true on CPU, but especially true on GPU and TPU). Extracting elements of arrays into other arrays via indexing will be relatively inefficient (again, true on CPU, but especially true on GPU and TPU).

What you're attempting here is to remove N / 2 duplicative (but very efficient) operations and replace them with N / 2 very inefficient operations. The result is going to be slower execution.

My suggestion would be to use your original function. Yes, it does twice as much work as strictly necessary, but the work it does is much more efficient than what would be required to avoid that work. The only time it would make sense to worry about this would be if the duplicative operations are much more complicated, to the point where introducing scatter and gather operations is preferable to doing duplicated work.

What do you think?
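One way to check this trade-off on a given machine is a small timing sketch like the one below. It compares a dense, purely elementwise evaluation against the gather-based triangular version. The names `fgrid_broadcast`, `fgrid_gather`, and `time_ms` are hypothetical helpers, not part of the original thread, and the dense version uses broadcasting rather than `meshgrid` plus `vmap`, which assumes `f` is built only from elementwise `jnp` operations as in the question. No timings are quoted here because they depend on hardware and JAX version.

```python
import timeit

import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.exp(-(a - b) ** 2)

@jax.jit
def fgrid_broadcast(samples):
    # Dense N x N evaluation: elementwise work only, no indexing.
    return f(samples[:, None], samples[None, :]).mean()

@jax.jit
def fgrid_gather(samples):
    # Roughly half as many evaluations, but each input is built by a gather.
    i, j = jnp.tril_indices(samples.shape[0])
    return f(samples[i], samples[j]).mean()

def time_ms(fn, x, repeats=20):
    # Warm-up call so compilation time is excluded from the measurement.
    fn(x).block_until_ready()
    # block_until_ready accounts for JAX's asynchronous dispatch.
    total = timeit.timeit(lambda: fn(x).block_until_ready(), number=repeats)
    return 1e3 * total / repeats

samples = jnp.linspace(0, 2, 1000)
dense_ms = time_ms(fgrid_broadcast, samples)
gather_ms = time_ms(fgrid_gather, samples)
print(f"dense: {dense_ms:.3f} ms, gather: {gather_ms:.3f} ms")
```

Note that the two functions return slightly different numbers, since the triangular mean does not apply the half-weighting on the diagonal; the sketch is meant only to compare runtime.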