Description
Problem description
Suppose we want to multiply a chain of $N$ matrices, each of size $n \times n$.
In particular, we want to perform a parallel associative reduction of non-scalar elements.
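For concreteness, here is the chain written out, along with a balanced-tree grouping of it for $N = 8$. The notation $M_1, \dots, M_N$ for the individual $n \times n$ factors is ours:

$$
\begin{aligned}
P &= M_1 M_2 \cdots M_N, \\
P &= \big((M_1 M_2)(M_3 M_4)\big)\big((M_5 M_6)(M_7 M_8)\big) \qquad (N = 8).
\end{aligned}
$$

Associativity makes the two groupings equal even though matrix multiplication is not commutative, as long as the left-to-right order of the factors is kept. All products within one level of the tree are independent. As a result, the reduction takes $\lceil \log_2 N \rceil$ levels and $N - 1$ matrix products in total. The identity $I_n$ plays the role of the neutral element for an empty chain.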
The following existing approaches aren't suitable:
- `lax.reduce` works only with scalars. (Note that the identity element in our case would be the $n \times n$ identity matrix, which is not a scalar.)
- `lax.scan` runs sequentially, so it does not perform a parallel associative reduction.
- `lax.associative_scan` followed by taking the last element (`[-1]`) is inefficient in both compute and memory.
The last point is confirmed by inspecting the generated HLO. The compiled program still performs the full "downsweep" (prefix-sum) phase of the scan algorithm, even though every intermediate prefix sum except the last is immediately discarded.
This results in:
- Computational overhead: we pay for $2 \log_2 N$ steps (upsweep + downsweep) instead of the necessary $\log_2 N$ steps (upsweep only).
- Unnecessary peak memory consumption: the compiler fails to optimize away the intermediate buffers required for the downsweep, which prevents the minimal memory footprint of a pure tree reduction.
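A back-of-the-envelope model of those step counts follows. It is pure Python and only counts combine calls; nothing is compiled. `tree_reduce_cost` counts the upsweep-only reduction. `odd_even_scan_cost` is a simplified model of an odd/even recursive scan with two sequential combine calls per recursion level. We assume `lax.associative_scan` follows this structure, but that is our reading and not something the HLO counts in this issue establish. Both function names are hypothetical.

```python
def tree_reduce_cost(n: int) -> tuple[int, int]:
    """(sequential combine calls, element-wise combines) for an upsweep-only reduction."""
    calls = combines = 0
    while n > 1:
        pairs, parity = divmod(n, 2)
        combines += pairs   # one batched combine over all neighbouring pairs
        n = pairs + parity  # an unpaired last element is carried to the next level
        calls += 1
    return calls, combines


def odd_even_scan_cost(n: int) -> tuple[int, int]:
    """Same counts for a simplified odd/even recursive prefix scan (assumed model)."""
    if n < 2:
        return 0, 0
    pairs = n // 2             # combine(elems[0:-1:2], elems[1::2])  (upsweep)
    calls, combines = odd_even_scan_cost(pairs)
    evens = (n - 1) // 2       # combine(prefixes, elems[2::2])       (downsweep)
    return calls + 2, combines + pairs + evens


for n in (5, 1024):
    print(n, tree_reduce_cost(n), odd_even_scan_cost(n))
# 5 (3, 4) (4, 5)
# 1024 (10, 1023) (20, 2036)
```

For $N = 1024$, the model gives 10 sequential combine calls for the reduction versus 20 for the scan. This matches the $\log_2 N$ versus $2 \log_2 N$ estimate above.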
The code below demonstrates the extra work by counting `dot_general` occurrences in the compiled HLO of each approach.
Proposed solution
Add `lax.associative_reduce`: a parallel tree-reduction primitive that performs only the upsweep phase.
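As a reference for the intended semantics, here is a minimal sequential sketch over a Python list. `pairwise_reduce` is a hypothetical name, and this is not the proposed JAX API. The sketch mirrors the halving loop in the code below: neighbouring pairs are combined in order, an unpaired last element is carried forward, and an identity is required only when the input is empty.

```python
_MISSING = object()  # sentinel: no identity supplied


def pairwise_reduce(fn, elems, identity=_MISSING):
    """Upsweep-only tree reduction; fn must be associative but need not be commutative."""
    elems = list(elems)
    if not elems:
        if identity is _MISSING:
            raise TypeError("Must specify identity for empty reduction.")
        return identity
    while len(elems) > 1:
        # Combine (0, 1), (2, 3), ...; on a parallel device each level is one batched op.
        reduced = [fn(elems[i], elems[i + 1]) for i in range(0, len(elems) - 1, 2)]
        if len(elems) % 2 == 1:
            reduced.append(elems[-1])  # carry the unpaired last element
        elems = reduced
    return elems[0]


# String concatenation is associative but not commutative, so it exposes ordering bugs.
print(pairwise_reduce(lambda a, b: a + b, ["a", "b", "c", "d", "e"]))  # abcde
print(repr(pairwise_reduce(lambda a, b: a + b, [], identity="")))     # ''
```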
Use case
This would be helpful for modern ML workloads (e.g., linear RNNs and SSMs like Mamba) where reducing a chain of operators is a bottleneck.
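To illustrate how operator chains arise here (our illustration, not taken from the issue), consider a linear recurrence $h_t = a_t h_{t-1} + b_t$. Each step is an affine map $(a_t, b_t)$, and composing affine maps is associative, so the final state is a single associative reduction over the steps. The sketch uses scalars for brevity. In a linear RNN or SSM, the $a_t$ would be matrices, which is exactly the non-scalar case above. `compose`, `tree_reduce`, and `run_sequential` are hypothetical helper names.

```python
def compose(first, second):
    """Compose two affine steps h -> a*h + b, applying `first` then `second`."""
    a1, b1 = first
    a2, b2 = second
    # second(first(h)) = a2*(a1*h + b1) + b2 = (a2*a1)*h + (a2*b1 + b2)
    return (a2 * a1, a2 * b1 + b2)


def tree_reduce(fn, elems):
    """Balanced-tree reduction that preserves element order."""
    if len(elems) == 1:
        return elems[0]
    mid = len(elems) // 2
    return fn(tree_reduce(fn, elems[:mid]), tree_reduce(fn, elems[mid:]))


def run_sequential(steps, h0):
    """Reference: evaluate h_t = a_t * h_{t-1} + b_t one step at a time."""
    h = h0
    for a, b in steps:
        h = a * h + b
    return h


steps = [(2, 1), (3, -1), (1, 4), (-1, 2), (2, 0)]
total_a, total_b = tree_reduce(compose, steps)
print(total_a, total_b)             # -12 -8
print(total_a * 5 + total_b)        # -68
print(run_sequential(steps, h0=5))  # -68
```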
Code
Here's code that demonstrates both the issue and the proposed solution:
```python
import re

import jax
import jax.numpy as jnp


class Unspecified:
    """Sentinel meaning "no identity element was supplied" (avoids importing from jax._src)."""


def array_take_axis(array: jax.Array, axis: int, index: int | slice | jax.Array):
    # Index `array` along `axis` only, leaving all other axes untouched.
    indices: list[int | slice | jax.Array] = [slice(None)] * array.ndim
    indices[axis] = index
    return array[tuple(indices)]


def tree_take_axis(tree, axis: int, index: int | slice | jax.Array):
    return jax.tree.map(lambda leaf: array_take_axis(leaf, axis, index), tree)


def count_dot_ops(func, example):
    # Compile `func` and count textual occurrences of "dot_general" in the
    # optimized HLO, as a rough proxy for the number of matmuls.
    lowered = jax.jit(func).lower(example)
    compiled = lowered.compile()
    text = compiled.as_text()
    assert text is not None
    matches = re.findall(r"dot_general", text)
    return len(matches)


def scan_full(fn, elems, axis: int):
    return jax.lax.associative_scan(fn, elems, axis=axis)


def scan_last(fn, elems, axis: int):
    # Full prefix scan, then keep only the final (total) element.
    full = scan_full(fn, elems, axis=axis)
    return tree_take_axis(full, axis, -1)


def tree_dim(tree, axis: int):
    # Common length of all leaves along `axis`.
    leaves = jax.tree.leaves(tree)
    dim = jnp.shape(leaves[0])[axis]
    if any(jnp.shape(leaf)[axis] != dim for leaf in leaves[1:]):
        raise ValueError(f"Tree leaves have unequal dimension along {axis=}.")
    return dim


def associative_reduce(fn, elems, axis: int, identity=Unspecified()):
    # Upsweep only: combine neighbouring pairs until one element remains.
    while tree_dim(elems, axis) > 1:
        n, parity = divmod(tree_dim(elems, axis), 2)
        evens = tree_take_axis(elems, axis, slice(0, n * 2, 2))
        odds = tree_take_axis(elems, axis, slice(1, n * 2, 2))
        reduced = fn(evens, odds)
        if parity == 1:
            # Carry the unpaired last element to the next level.
            last = tree_take_axis(elems, axis, slice(-1, None))
            elems = jax.tree.map(
                lambda reduced, last: jnp.concatenate([reduced, last], axis=axis),
                reduced,
                last,
            )
        else:
            elems = reduced
    if tree_dim(elems, axis) == 1:
        return jax.tree.map(lambda leaf: leaf.squeeze(axis=axis), elems)
    else:
        if isinstance(identity, Unspecified):
            raise TypeError("Must specify identity for empty reduction.")
        else:
            return identity


def main():
    SEQ_LEN = 1024
    MATRIX_SIZE = 16

    def fn(a, b):
        return {"key": jnp.matmul(a["key"], b["key"])}

    example = {
        "key": jax.ShapeDtypeStruct((5, SEQ_LEN, MATRIX_SIZE, MATRIX_SIZE), jnp.float32)
    }
    axis = 1
    print(f"{'method':20} {'dots':>10}")
    for label, method in [
        ("scan_full", scan_full),
        ("scan_last", scan_last),
        ("associative_reduce", associative_reduce),
    ]:
        dots = count_dot_ops(lambda elems: method(fn, elems, axis), example)
        print(f"{label:20} {dots:>10}")


if __name__ == "__main__":
    main()
```

Output:

```
method                     dots
scan_full                    81
scan_last                    81
associative_reduce           37
```