Is there a way to get the jax.jit/XLA pipeline to optimize the order of matrix multiplications? I'm currently working on a project that requires lots of chained linear transformations on large coordinate grids, i.e. of the form

```python
A_1 @ A_2 @ A_3 @ ... @ A_n @ coords  # A_n: [3 x 3], coords: [3 x 100000]
```

However, due to dynamic (user-defined) behaviour at compile time, the code will actually look more like

```python
coords = <initial_coordinates>
coords = A_n @ coords
...
coords = A_2 @ coords
coords = A_1 @ coords
```

Clearly, doing this naively is much less efficient than the left-associative evaluation, which combines the small [3 x 3] matrices first and touches the large coords array only once. I was hoping that just wrapping my code in a jax.jit would make XLA reorder the contractions automatically; however, that doesn't appear to be the case. Is this not done for numerical reasons, or will it be supported at some point?

Here's a MWE of what I'm trying to do:

```python
import jax
import jax.numpy as jnp


def make_fn(shortcut=False):
    def fn(theta):
        # Create some coordinates and a simple rotation matrix
        coords = jnp.mgrid[-3000:3000, -3000:3000].reshape(2, -1) / 100
        rot = jnp.array([[jnp.cos(theta), jnp.sin(theta)], [-jnp.sin(theta), jnp.cos(theta)]])
        if shortcut:
            transform = rot @ -rot @ rot @ -rot @ rot @ -rot @ rot @ -rot
            coords = transform @ coords
        else:
            coords = rot @ coords
            coords = -rot @ coords
            coords = rot @ coords
            coords = -rot @ coords
            coords = rot @ coords
            coords = -rot @ coords
            coords = rot @ coords
            coords = -rot @ coords
        return coords

    return fn


fn_naive = make_fn(shortcut=False)
fn_quick = make_fn(shortcut=True)
```

Timing the functions with and without the "manual" shortcut makes quite the difference:
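A rough comparison can be timed along these lines (a sketch; make_fn is the function from the MWE above, and block_until_ready accounts for JAX's asynchronous dispatch):

```python
import timeit

import jax

fn_naive_jit = jax.jit(make_fn(shortcut=False))
fn_quick_jit = jax.jit(make_fn(shortcut=True))

# Warm-up calls so compilation time is excluded from the measurement.
fn_naive_jit(0.1).block_until_ready()
fn_quick_jit(0.1).block_until_ready()

print("naive:", timeit.timeit(lambda: fn_naive_jit(0.1).block_until_ready(), number=10))
print("quick:", timeit.timeit(lambda: fn_quick_jit(0.1).block_until_ready(), number=10))
```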
Looking at the generated HLO via
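One way to dump both the lowered program and the HLO after XLA's optimization passes is the ahead-of-time jit API (a sketch, assuming a recent JAX release):

```python
import jax

fn_naive_jit = jax.jit(make_fn(shortcut=False))
lowered = fn_naive_jit.lower(0.1)

print(lowered.as_text())            # what JAX hands to XLA
print(lowered.compile().as_text())  # HLO after XLA's optimization passes
```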
Replies: 1 comment
I don't think XLA does optimization of matmul orderings, but JAX does expose `jnp.linalg.multi_dot` for this purpose.
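For a dynamically built chain like the one in the question, the per-step transforms could be collected in a list and contracted with `jnp.linalg.multi_dot`, which picks an efficient multiplication order from the (static) shapes. A minimal sketch based on the MWE above, with a hypothetical apply_chain wrapper:

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_chain(theta, coords):
    # Stand-in for the user-defined sequence of transforms.
    rot = jnp.array([[jnp.cos(theta), jnp.sin(theta)], [-jnp.sin(theta), jnp.cos(theta)]])
    transforms = [rot, -rot, rot, -rot, rot, -rot, rot, -rot]

    # multi_dot chooses the association order, so the small 2x2 matrices are
    # combined before anything touches the large coordinate array.
    return jnp.linalg.multi_dot(transforms + [coords])


coords = jnp.mgrid[-3000:3000, -3000:3000].reshape(2, -1) / 100
rotated = apply_chain(0.1, coords)
```

Because the shapes are known at trace time, this recovers the cost of the manual shortcut without hard-coding the association order.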