
Add lax.associative_reduce (parallel tree reduction without downsweep) #35118

@carlosgmartin

Description


Problem description

Suppose we want to multiply a chain of $n \times n$ matrices as efficiently as possible.

In particular, we want to perform a parallel associative reduction of non-scalar elements.
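For concreteness: given matrices $M_1, \dots, M_N$, the target is their ordered product. Matrix multiplication is associative (though not commutative), so adjacent pairs can be multiplied independently in each round, halving the number of factors per round:

$$
M_1 M_2 \cdots M_N = \bigl((M_1 M_2)(M_3 M_4)\bigr)\bigl((M_5 M_6)(M_7 M_8)\bigr)\cdots
$$

A balanced tree therefore needs $\lceil \log_2 N \rceil$ rounds and $N - 1$ multiplications in total (for $N = 8$: $4 + 2 + 1 = 7$ products over 3 rounds), whereas a left-to-right product needs the same $N - 1$ multiplications but in $N - 1$ dependent steps.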

The following existing approaches aren't suitable:

  • lax.reduce works only with scalars. (Note that the identity element in our case would be the $n \times n$ identity matrix, which is not a scalar.)

  • lax.scan is sequential: it applies the step function one element at a time, so it cannot exploit associativity to reduce in parallel.

  • Calling lax.associative_scan and taking the last element ([-1]) is inefficient in both compute and memory.
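For comparison, the sequential lax.scan route is easy to write: carry the running product, start from the $n \times n$ identity matrix, and apply one matmul per step, so the depth grows linearly with the chain length. This is a minimal sketch; `chain_product_scan` is a hypothetical helper name, not a JAX API.

```python
import functools

import jax
import jax.numpy as jnp


def chain_product_scan(mats: jax.Array) -> jax.Array:
  """Ordered product mats[0] @ mats[1] @ ... @ mats[-1] via lax.scan.

  Each step depends on the previous carry, so the matmuls run one after
  another: the sequential depth is linear in the chain length.
  """
  n = mats.shape[-1]
  identity = jnp.eye(n, dtype=mats.dtype)  # non-scalar identity element

  def step(carry, m):
    return carry @ m, None  # running product times the next factor

  product, _ = jax.lax.scan(step, identity, mats)
  return product


mats = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 4)) / 2.0
expected = functools.reduce(jnp.matmul, list(mats))  # left-to-right reference
print(jnp.allclose(chain_product_scan(mats), expected, rtol=1e-4, atol=1e-4))
```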

Inspecting the generated HLO confirms the last point: it still performs the full "downsweep" (prefix-sum) phase of the scan algorithm, even though the intermediate prefix sums are immediately discarded.

This results in:

  • Computational overhead: we pay for $2 \log_2 N$ steps (upsweep + downsweep) instead of the necessary $\log_2 N$ steps (upsweep only), where $N$ is the chain length.

  • Unnecessary peak memory consumption: the compiler does not eliminate the intermediate buffers required by the downsweep, so peak memory stays above the minimal footprint of a pure tree reduction.
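To make the work overhead concrete without relying on HLO text, here is a pure-Python sketch that counts combine applications (total work, not sequential steps) for an upsweep-only pairwise reduction and for a recursive odd/even prefix scan. The scan is modeled on the recursion lax.associative_scan uses (pair up, recurse on the half-length result, then fill in the even positions); it is a model, not JAX's code, and integer addition stands in for the matrix product.

```python
def tree_reduce_count(xs: list[int]) -> tuple[int, int, int]:
  """Upsweep-only pairwise reduction; returns (result, combines, rounds)."""
  combines = rounds = 0
  while len(xs) > 1:
    # One round: combine elements (0, 1), (2, 3), ... in parallel.
    paired = [xs[i] + xs[i + 1] for i in range(0, len(xs) - 1, 2)]
    combines += len(paired)
    if len(xs) % 2 == 1:
      paired.append(xs[-1])  # odd length: carry the last element forward
    xs = paired
    rounds += 1
  return xs[0], combines, rounds


def odd_even_scan_count(xs: list[int]) -> tuple[list[int], int]:
  """Recursive odd/even inclusive prefix scan; returns (prefixes, combines)."""
  n = len(xs)
  if n < 2:
    return list(xs), 0
  # Upsweep: reduced[j] = xs[2j] + xs[2j + 1].
  reduced = [xs[i] + xs[i + 1] for i in range(0, n - 1, 2)]
  odd, inner = odd_even_scan_count(reduced)  # odd[j] = prefix at index 2j + 1
  # Downsweep: prefix at index 2k (k >= 1) = odd[k - 1] + xs[2k].
  tail = xs[2::2]
  evens = [xs[0]] + [odd[k] + tail[k] for k in range(len(tail))]
  prefixes = [evens[i // 2] if i % 2 == 0 else odd[i // 2] for i in range(n)]
  return prefixes, len(reduced) + inner + len(tail)


data = list(range(1, 1025))
total, reduce_work, rounds = tree_reduce_count(data)
prefixes, scan_work = odd_even_scan_count(data)
assert prefixes[-1] == total == sum(data)
print(reduce_work, rounds, scan_work)  # 1023 10 2036
```

For $N = 1024$ the upsweep-only reduction does $N - 1 = 1023$ combines in 10 rounds, while the full scan does 2036, roughly twice the work, only to keep the last prefix.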

The code below demonstrates the compute overhead by counting dot_general occurrences in the compiled HLO.

Proposed solution

Add lax.associative_reduce: A parallel tree-reduction primitive that performs only the upsweep phase.
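One upsweep level can be expressed as a single batched call to the combine function. A minimal sketch, assuming a single array (not a pytree), reduction along axis 0, and a power-of-two length so no element is ever left unpaired; `upsweep_reduce_pow2` is a hypothetical name, not the proposed lax API.

```python
import functools

import jax
import jax.numpy as jnp


def upsweep_reduce_pow2(fn, elems: jax.Array) -> jax.Array:
  """Upsweep-only tree reduction along axis 0 (hypothetical helper).

  Assumes elems.shape[0] is a power of two, so every level pairs all
  elements with no leftover and each level is a single batched call to fn.
  """
  n = elems.shape[0]
  if n == 0 or n & (n - 1):
    raise ValueError(f"expected a power-of-two length, got {n}")
  while elems.shape[0] > 1:
    # View as (length / 2, 2, ...): pair i holds elements 2i and 2i + 1.
    pairs = elems.reshape((elems.shape[0] // 2, 2) + elems.shape[1:])
    elems = fn(pairs[:, 0], pairs[:, 1])
  return elems[0]


mats = jax.random.normal(jax.random.PRNGKey(1), (16, 3, 3)) / 2.0
tree = upsweep_reduce_pow2(jnp.matmul, mats)  # 4 levels, 15 matmuls in total
seq = functools.reduce(jnp.matmul, list(mats))  # left-to-right reference
print(jnp.allclose(tree, seq, rtol=1e-4, atol=1e-4))
```

Reshaping to (length / 2, 2, ...) is just one way to form the pairs; the implementation in the Code section uses strided slices instead, which also handles odd lengths and pytrees.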

Use case

This would be helpful for modern ML workloads (e.g., linear RNNs and SSMs like Mamba) where reducing a chain of operators is a bottleneck.
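As an illustration, consider a simplified diagonal linear recurrence $h_t = a_t \odot h_{t-1} + b_t$ (the general shape of recurrence in linear RNNs and diagonal SSMs; real models such as Mamba add discretization and input-dependent parameters on top). Each step is an affine map, and composing affine maps is associative, so the final state is a reduction of the chain of step operators. A minimal sketch, assuming $h_0 = 0$; `compose`, `tree_reduce`, and `final_state_scan` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def compose(first, second):
  """Compose elementwise affine maps h -> a * h + b (apply `first`, then `second`).

  second(first(h)) = a2 * (a1 * h + b1) + b2 = (a2 * a1) * h + (a2 * b1 + b2).
  """
  a1, b1 = first
  a2, b2 = second
  return a2 * a1, a2 * b1 + b2


def tree_reduce(fn, elems):
  """Upsweep-only pairwise reduction along axis 0 of a pytree (hypothetical helper)."""
  n = jax.tree.leaves(elems)[0].shape[0]
  while n > 1:
    earlier = jax.tree.map(lambda x: x[0 : n - 1 : 2], elems)
    later = jax.tree.map(lambda x: x[1:n:2], elems)
    reduced = fn(earlier, later)
    if n % 2 == 1:
      # Odd length: carry the unpaired last element into the next level.
      reduced = jax.tree.map(
          lambda r, x: jnp.concatenate([r, x[-1:]]), reduced, elems)
    elems = reduced
    n = jax.tree.leaves(elems)[0].shape[0]
  return jax.tree.map(lambda x: x[0], elems)


def final_state_scan(a, b):
  """Sequential reference: h_t = a_t * h_{t-1} + b_t with h_0 = 0."""

  def step(h, ab):
    a_t, b_t = ab
    return a_t * h + b_t, None

  h, _ = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
  return h


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
T, D = 37, 8  # odd length exercises the carry path
a = jax.random.uniform(k1, (T, D), minval=0.5, maxval=1.0)  # per-channel decays
b = jax.random.normal(k2, (T, D))  # per-step inputs
_, h_tree = tree_reduce(compose, (a, b))  # applied to h_0 = 0, only b survives
print(jnp.allclose(h_tree, final_state_scan(a, b), rtol=1e-4, atol=1e-4))
```

Only the final composed operator is needed here, which is exactly the case where the downsweep of an associative scan is wasted work.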

Code

Here's code that demonstrates both the issue and the proposed solution:

import re

import jax
import jax.numpy as jnp


# Local sentinel meaning "no identity was passed"; avoids importing the
# private jax._src.tree_util.Unspecified.
class Unspecified:
  pass


def array_take_axis(array: jax.Array, axis: int, index: int | slice | jax.Array):
  indices: list[int | slice | jax.Array] = [slice(None)] * array.ndim
  indices[axis] = index
  return array[tuple(indices)]


def tree_take_axis(tree, axis: int, index: int | slice | jax.Array):
  return jax.tree.map(lambda leaf: array_take_axis(leaf, axis, index), tree)


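def count_dot_ops(func, example):
  # Compile `func` for the abstract `example` input and count occurrences of
  # the string "dot_general" in the optimized HLO text. This is a rough proxy
  # for the number of matmuls, not an exact instruction count.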
  lowered = jax.jit(func).lower(example)
  compiled = lowered.compile()
  text = compiled.as_text()
  assert text is not None
  matches = re.findall(r"dot_general", text)
  return len(matches)


def scan_full(fn, elems, axis: int):
  return jax.lax.associative_scan(fn, elems, axis=axis)


def scan_last(fn, elems, axis: int):
  full = scan_full(fn, elems, axis=axis)
  return tree_take_axis(full, axis, -1)


def tree_dim(tree, axis: int):
  leaves = jax.tree.leaves(tree)
  dim = jnp.shape(leaves[0])[axis]
  if any(jnp.shape(leaf)[axis] != dim for leaf in leaves[1:]):
    raise ValueError(f"Tree leaves have unequal dimension along {axis=}.")
  return dim


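def associative_reduce(fn, elems, axis: int, identity=Unspecified()):
  """Parallel tree reduction (upsweep only) of `elems` along `axis`.

  Each iteration combines adjacent pairs with `fn`, halving the length, so
  the loop runs ceil(log2(N)) times. `identity` is returned for empty input.
  """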
  while tree_dim(elems, axis) > 1:
    n, parity = divmod(tree_dim(elems, axis), 2)
    evens = tree_take_axis(elems, axis, slice(0, n * 2, 2))
    odds = tree_take_axis(elems, axis, slice(1, n * 2, 2))
    reduced = fn(evens, odds)
    if parity == 1:
      # Odd length: carry the unpaired last element into the next level.
      last = tree_take_axis(elems, axis, slice(-1, None))
      elems = jax.tree.map(
        lambda reduced, last: jnp.concatenate([reduced, last], axis=axis),
        reduced,
        last,
      )
    else:
      elems = reduced

  if tree_dim(elems, axis) == 1:
    return jax.tree.map(lambda leaf: leaf.squeeze(axis=axis), elems)
  else:
    if isinstance(identity, Unspecified):
      raise TypeError("Must specify identity for empty reduction.")
    else:
      return identity


def main():
  SEQ_LEN = 1024
  MATRIX_SIZE = 16

  def fn(a, b):
    return {"key": jnp.matmul(a["key"], b["key"])}

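  # Abstract input: 5 independent chains, each of SEQ_LEN matrices of size
  # MATRIX_SIZE x MATRIX_SIZE; nothing is materialized, only lowered/compiled.
  example = {
    "key": jax.ShapeDtypeStruct((5, SEQ_LEN, MATRIX_SIZE, MATRIX_SIZE), jnp.float32)
  }

  axis = 1  # reduce along the SEQ_LEN axis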

  print(f"{'method':20} {'dots':>10}")
  for label, method in [
    ("scan_full", scan_full),
    ("scan_last", scan_last),
    ("associative_reduce", associative_reduce),
  ]:
    dots = count_dot_ops(lambda elems: method(fn, elems, axis), example)
    print(f"{label:20} {dots:>10}")


if __name__ == "__main__":
  main()

Output:

method                dots
scan_full               81
scan_last               81
associative_reduce      37
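The regex in count_dot_ops matches every occurrence of the string "dot_general" in the optimized HLO text, which may include metadata annotations as well as instructions, so the absolute numbers above are best read as a rough proxy. A cross-check that avoids compiler output is to count dot_general equations in the traced jaxpr. A minimal sketch: `count_primitive` and `pairwise_reduce` are hypothetical helpers, sub-jaxpr detection is duck-typed rather than relying on internal jaxpr classes, and the count is of trace-time call sites, not executions.

```python
import jax
import jax.numpy as jnp


def count_primitive(closed_jaxpr, name: str) -> int:
  """Count equations whose primitive is `name`, recursing into sub-jaxprs.

  This counts trace-time call sites: an op inside a lax.scan or while_loop
  body is counted once even though it runs many times.
  """

  def subjaxprs(value):
    if hasattr(value, "eqns"):  # Jaxpr (or ClosedJaxpr exposing .eqns)
      yield value
    elif hasattr(value, "jaxpr") and hasattr(value.jaxpr, "eqns"):  # ClosedJaxpr
      yield value.jaxpr
    elif isinstance(value, (tuple, list)):  # e.g. cond branches
      for item in value:
        yield from subjaxprs(item)

  def walk(jaxpr) -> int:
    total = 0
    for eqn in jaxpr.eqns:
      total += int(eqn.primitive.name == name)
      for param in eqn.params.values():
        for sub in subjaxprs(param):
          total += walk(sub)
    return total

  return walk(closed_jaxpr.jaxpr)


def pairwise_reduce(fn, x):
  """Upsweep-only reduction along axis 0 of a single array (hypothetical helper)."""
  while x.shape[0] > 1:
    n = x.shape[0]
    reduced = fn(x[0 : n - 1 : 2], x[1:n:2])
    # Odd length: carry the unpaired last element into the next level.
    x = jnp.concatenate([reduced, x[-1:]]) if n % 2 else reduced
  return x[0]


x = jnp.zeros((1024, 16, 16), jnp.float32)  # only shapes matter for tracing
n_scan = count_primitive(
    jax.make_jaxpr(lambda m: jax.lax.associative_scan(jnp.matmul, m)[-1])(x),
    "dot_general")
n_reduce = count_primitive(
    jax.make_jaxpr(lambda m: pairwise_reduce(jnp.matmul, m))(x), "dot_general")
print(f"associative_scan[-1]: {n_scan}, pairwise_reduce: {n_reduce}")
```

Under these assumptions, pairwise_reduce traces one combine call per level, while the scan traces roughly two per level (one on the way up and one on the way down), so the scan count should come out at about twice the reduction count.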

Labels: enhancement (New feature or request)
