Reduce functionality for vmap #9505
-
I was trying to compute a Gauss-Newton matrix and had some difficulties aggregating the outer products of the Jacobians in an efficient manner, i.e., a sum of outer products generated by large parameter vectors, over the entire dataset. Essentially, what this came down to was: I was looking for some aggregation (reduce) functionality for the outputs of `jax.vmap`. Of course, I know that I can split this computation up into batches and use a loop, which will be equivalent. But I was actually wondering how I could implement this as a more principled map-reduce procedure in JAX. I've tried playing around with `jax.lax.reduce`, without much luck. Below is some code that summarizes what I would like to do.

Unviable solution:

```python
import jax
import jax.numpy as jnp

n = 200
d = 1000
a = jnp.arange(n * d).reshape((n, d))

# Works for slightly larger d
jnp.outer(a[0], a[0])

# Fails for slightly larger d: the full (n, d, d) stack of outer products is materialized
res = jax.vmap(jnp.outer)(a, a)
print(res.sum(axis=0))
```

Naive viable solution:

```python
@jax.jit
def outer_sum(arr):
    result = 0
    for a in arr:
        result += jnp.outer(a, a)
    return result

incremental_outer_sum = outer_sum(a)
print(incremental_outer_sum)
```

Python map-reduce (the fastest of the three so far):

```python
import functools

map_reduce = functools.reduce(jnp.add, map(jnp.outer, a, a))
print(map_reduce)
```

So, could a JAX primitive-based implementation pose a speed-up over the Python map-reduce approach? And how would this work?
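For completeness, here is a minimal sketch of the "split into batches and use a loop" option mentioned above. This is my own sketch, not part of the original post; the chunk size and the assumption that `n` is divisible by it are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def chunked_outer_sum(a, chunk_size=50):
    # Assumes n is divisible by chunk_size for simplicity.
    n, d = a.shape
    total = jnp.zeros((d, d), a.dtype)
    for start in range(0, n, chunk_size):
        chunk = a[start:start + chunk_size]
        # Only a (chunk_size, d, d) intermediate is materialized per iteration.
        total = total + jax.vmap(jnp.outer)(chunk, chunk).sum(axis=0)
    return total

n, d = 200, 1000
a = jnp.arange(n * d, dtype=jnp.float32).reshape((n, d))
print(chunked_outer_sum(a))
```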
-
I just figured out that I can also be using `jnp.einsum`:

```python
# For the 2D `a` above, this contracts over the batch axis,
# i.e. it computes the sum of the per-row outer products directly.
einsum_result = jnp.einsum('nc,nd->cd', a, a)
print(einsum_result)
```

Still, a reduce functionality would come in handy for other use cases.
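As a quick sanity check (my own addition, not from the discussion), for the 2D example the einsum contraction agrees with the vmap-plus-sum result and with a plain matrix product:

```python
import jax
import jax.numpy as jnp

n, d = 200, 100  # small enough that the (n, d, d) vmap intermediate fits in memory
a = jnp.arange(n * d, dtype=jnp.float32).reshape((n, d))

via_vmap = jax.vmap(jnp.outer)(a, a).sum(axis=0)
via_einsum = jnp.einsum('nc,nd->cd', a, a)
via_matmul = a.T @ a

print(jnp.allclose(via_vmap, via_einsum), jnp.allclose(via_einsum, via_matmul))
```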
-
You just need to jit it, then XLA will optimize the computation for you. (However, this cannot be observed with `jax.xla_computation`.)

```python
n = 200
d = 5000
a = jnp.ones((n, d), jnp.float32)

def f(x):
    return jnp.sum(jax.vmap(jnp.outer)(x, x), axis=0)

print(f(a))           # fails with n = 5000, d = 1000 or n = 200, d = 5000 (on 16G V100)
print(jax.jit(f)(a))  # succeeds with n = 2000000, d = 1000 or n = 200, d = 60000 (on 16G V100)
```

Comparison with explicit map-reduce semantics:

```python
n = 200
d = 5000
a = jnp.ones((n, d), jnp.float32)

def f(x):
    return jnp.sum(jax.vmap(jnp.outer)(x, x), axis=0)

def g(xs):  # explicit map-reduce, succeeds even without jit
    def scan_fun(carry, x):
        return carry + jnp.outer(x, x), None
    out = jax.eval_shape(jnp.outer, xs[0], xs[0])
    return jax.lax.scan(scan_fun, jnp.zeros(out.shape, out.dtype), xs)[0]

print(jax.xla_computation(f)(a).as_hlo_text())
print(jax.xla_computation(g)(a).as_hlo_text())
```

You will see that f's unoptimized HLO still contains the large (n, d, d) intermediate, while g's carries only the (d, d) accumulator through the scan.
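If you do want to see what XLA actually compiles, reasonably recent JAX versions expose the post-optimization HLO through the ahead-of-time lowering API. This is my own sketch, not from the thread, and it assumes a JAX version where `Lowered.compile()` and `Compiled.as_text()` are available:

```python
import jax
import jax.numpy as jnp

n, d = 200, 5000
a = jnp.ones((n, d), jnp.float32)

def f(x):
    return jnp.sum(jax.vmap(jnp.outer)(x, x), axis=0)

compiled = jax.jit(f).lower(a).compile()
print(compiled.as_text())  # post-optimization HLO, where the fusion can be inspected
```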
`jax.lax.reduce`'s semantic is just a reduce over its operands, so the large intermediate array will not be eliminated without optimization. `jax.lax.scan` has a map-reduce "semantic", but optimization may break it for higher parallelism when the intermediate array should be stored (for examp…
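To make the `jax.lax.reduce` point concrete, here is a minimal sketch of my own (not from the thread): the reduction is expressed over an operand that has already been built, so the (n, d, d) stack still appears as an intermediate in the unoptimized computation. The small sizes are chosen so it runs without jit.

```python
import jax
import jax.numpy as jnp

n, d = 200, 100  # small sizes so the intermediate fits in memory
a = jnp.ones((n, d), jnp.float32)

def reduce_version(x):
    outers = jax.vmap(jnp.outer)(x, x)  # (n, d, d) operand is built first
    # lax.reduce only reduces this operand over axis 0; it does not
    # fuse the map and the reduce by itself.
    return jax.lax.reduce(outers, jnp.float32(0), jax.lax.add, (0,))

print(reduce_version(a).shape)  # (d, d)
```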