diff --git a/docs/_static/pallas/distributed/collective_all_gather_operands.svg b/docs/_static/pallas/distributed/collective_all_gather_operands.svg
new file mode 100644
index 000000000000..a47610d754f3
--- /dev/null
+++ b/docs/_static/pallas/distributed/collective_all_gather_operands.svg
@@ -0,0 +1,25 @@
+
diff --git a/docs/pallas/gpu/collective_matmul.md b/docs/pallas/gpu/collective_matmul.md
new file mode 100644
index 000000000000..bea6a3754a23
--- /dev/null
+++ b/docs/pallas/gpu/collective_matmul.md
@@ -0,0 +1,305 @@
+# Collective matrix multiplication
+
+Tensor parallelism (TP) and data parallelism (DP) are the most frequently used
+parallelism techniques that make it possible to fit ever larger models onto a
+growing number of accelerators. However, using them together means that our
+programs sometimes end up with data sharded in ways that make it impossible to
+execute an operation directly, without additional communication. One such
+situation frequently arises at the beginning of the MLP block of a Transformer:
+there, the input activations might be sharded on the batch axis (DP), while the
+weights might be partitioned on the output feature dimension (TP).
+
+
+
+The contraction dimension is not sharded, so it might seem that we can just
+multiply the two operands locally. But there is a problem: the result would
+have to be sharded along the same device axis on both of its dimensions, which
+is not a valid sharding!
+
+There's a simple way to solve this problem: we can all-gather the activations
+or the weights (here we focus on the activation side), and then perform a local
+matrix multiplication with the other operand still sharded. This strategy
+works, but it has a downside: we can't begin computing the matrix
+multiplication while the all-gather is still running, which leaves our hardware
+underutilized!
+
+To achieve better utilization, we'll show how simple it is to implement a
+Pallas:MGPU kernel that overlaps the cross-device communication with the
+matrix multiplication, achieving almost optimal utilization on large enough
+problem shapes. Our implementation makes heavy use of the NVLINK interconnect,
+which allows us to perform high-bandwidth inter-GPU communication without
+involving the host.
+
+This approach already yields considerable performance improvements! If we
+consider an f16 matmul with M=1024, K=4096 and N=4096 and normally distributed
+data, our benchmarks indicate that it should take about 43us on a single H100.
+In the table below, we scale up the M dimension so that the per-shard shape
+stays at M=1024. We can compute an expected lower bound for the execution of
+our distributed kernel by multiplying that local runtime estimate by the number
+of devices and adding about 6us for each of the rounds of communication (the
+memory fences associated with the synchronization are expensive). Benchmarking
+our kernel yields the following results:
+
+| Device count | Kernel time | Kernel TC utilization | Lower bound | Lower bound TC utilization | Reference time | Reference TC utilization |
+|--------------|-------------|-----------------------|-------------|----------------------------|----------------|---------------------------|
+| 2 | 102us | 68% | 92us | 75% | 147us | 47% |
+| 4 | 212us | 66% | 190us | 73% | 290us | 48% |
+| 8 | 436us | 64% | 386us | 72% | 565us | 49% |
+
+As you can see, there are still some opportunities for optimization here, but
+we're already getting much better utilization than the baseline implementation
+built from an NCCL all-gather followed by a cuBLAS matmul.
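+
+The lower-bound column above can be reproduced with a quick back-of-the-envelope
+calculation. The snippet below is just that arithmetic spelled out, using the
+43us per-shard matmul time and the ~6us per communication round quoted above:
+
+```python
+local_matmul_us = 43  # measured time of one per-shard matmul on a single H100
+comm_round_us = 6     # estimated cost of each round of communication
+for num_devices in (2, 4, 8):
+  # One local matmul per shard, plus one communication round per step except the last.
+  lower_bound = num_devices * local_matmul_us + (num_devices - 1) * comm_round_us
+  print(f"{num_devices} devices: ~{lower_bound}us")  # ~92us, ~190us, ~386us
+```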
+
+## Algorithm overview: Ring All-Gather
+
+To compute `AllGather(A) @ B`, we form a ring on the participating `D` devices.
+At each step, the device takes the last received shard (starting from its local
+shard), and passes it to the next device in the ring. While the send is
+happening, we compute the matrix multiplication between the last received `A` shard
+and the local `B` shard.
+
+
+
+More formally, the algorithm proceeds in `D` steps. In step `i` (`0 <= i < D`),
+device `d` receives shard `A_{(d + i) % D}` from device `(d + 1) % D` (in the
+first step it simply uses its local shard instead), computes
+`A_{(d + i) % D} @ B_d`, and writes the result to a slice of the output buffer.
+Concurrently with the compute, device `d` sends shard `A_{(d + i) % D}` to
+device `(d - 1) % D` for its use in step `i + 1` (we don't send in the last
+step). After `D` steps, device `d` will have seen every shard of `A` and
+computed the full output.
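+
+To make the indexing concrete, here is a tiny, pure-Python model of the ring
+schedule (illustrative only; the device count of 4 is an arbitrary example):
+
+```python
+D = 4  # example device count
+for d in range(D):  # device index
+  # Shard of A that device d multiplies with B_d at step i.
+  schedule = [(d + i) % D for i in range(D)]
+  # The shard used at step i is forwarded to device (d - 1) % D, which uses it
+  # at step i + 1. After D steps every device has seen every shard exactly once.
+  assert sorted(schedule) == list(range(D))
+```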
+
+## Pallas primitives for inter-device communication
+
+We use three Pallas functions for inter-device communication (a small usage
+sketch follows the list):
+
+* **`plgpu.remote_ref(ref, device_id)`**: This function takes a reference to a
+ buffer in global memory (GMEM) and returns a reference to the same buffer on a
+ *different* device, specified by `device_id`. When communicating over NVLINK,
+ this reference can be read or written to directly, even though its data is located
+ in remote memory.
+* **`pl.semaphore_signal(sem, device_id=...)`**: Increments a semaphore on a
+ target device. This is usually used to indicate completion of some process,
+ such as when we notify the remote device that the data it's waiting for has
+ been sent.
+* **`pl.semaphore_wait(sem, value=..., decrement=...)`**: Blocks until a local
+  semaphore reaches a certain value. If `decrement` is `True` (the default), the
+ value of the semaphore is decreased by the awaited amount. If it is `False`,
+ the operation is more efficient, but it does not modify the value of the
+ semaphore after the wait completes. This is frequently used to await signals
+ from a remote device.
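+
+Putting these together, here is a minimal, hypothetical sketch of a "push, then
+signal" exchange with the neighbor we send to in the ring. All of the names
+(`x_ref`, `recv_ref`, `done_sem`) are made up for illustration, and unlike the
+real kernel below (which streams data with TMA copies), it writes to the remote
+reference directly:
+
+```python
+def push_and_signal(x_ref, recv_ref, done_sem, axis_name):
+  # The device we send to: one position down the ring.
+  axis_size = lax.axis_size(axis_name)
+  send_dev_id = lax.rem(lax.axis_index(axis_name) + axis_size - 1, axis_size)
+  # A reference to recv_ref, but located in the target device's memory.
+  remote_recv_ref = plgpu.remote_ref(
+      recv_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL
+  )
+  remote_recv_ref[...] = x_ref[...]  # Write directly into remote GMEM (over NVLINK).
+  # Tell the target device that its data has arrived...
+  pl.semaphore_signal(done_sem, device_id=send_dev_id)
+  # ...and wait until our own recv_ref has been filled by our other neighbor.
+  pl.semaphore_wait(done_sem, value=1)
+```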
+
+## Implementation with Pallas
+
+```{note}
+Here, we only present a simplified version of the kernel, which allows us to
+focus on the most interesting details. You can find [the full implementation in
+our examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py).
+```
+
+First, we focus on the setup of our kernel. For the compute part, we will reuse
+our optimized matmul kernel implementation from `hopper_matmul_mgpu`. Since the
+compute kernel uses warp specialization, we use 3 Pallas threads. The kernel is
+also persistent, which means that we launch a grid as large as the number of
+SMs (queried from `.core_count` on the JAX device). The compute kernel uses
+`pl.run_scoped` for SMEM allocations, so we don't use `scratch_shapes`.
+
+```python
+def all_gather_lhs_matmul(
+ lhs: jax.Array,
+ rhs: jax.Array,
+ axis_name,
+ *,
+ config: hopper_matmul_mgpu.TuningConfig,
+ dtype: jnp.dtype = jnp.bfloat16,
+) -> jax.Array:
+ if (num_devices := jax.device_count()) != jax.process_count():
+ raise ValueError("The kernel only supports one device per process")
+ if (axis_size := lax.axis_size(axis_name)) != num_devices:
+ raise ValueError("The kernel can only work over all devices in a Mesh.")
+ ...
+
+ m_shard, k = lhs.shape
+ _, n_shard = rhs.shape
+ tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
+ cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
+ num_sms = jax.extend.backend.get_default_device().core_count
+
+ def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref):
+ ...
+
+ result, _ = plgpu.kernel(
+ kernel_body,
+ out_shape=[
+ # The output (with M gathered)
+ jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
+ # A scratch buffer for LHS all-gather
+ jax.ShapeDtypeStruct((axis_size - 1, m_shard, k), dtype),
+ ],
+ grid=(num_sms,),
+ num_threads=3, # The matmul kernel uses 3 threads: 2 compute and 1 memory
+ thread_name="wg",
+ )(lhs, rhs)
+ return result
+```
+
+The kernel above has two outputs. The first one is the actual result of our
+primitive, while the second one is used as a scratch space to receive the left
+operand shards. Note that we could shrink the leading axis to be smaller than
+`axis_size - 1`, but at that point we would need to apply backpressure to the
+sending devices, which requires additional expensive communication.
+
+```{note}
+You can see how to deal with this backpressure in the [TPU distributed communication guide](../tpu/distributed.md#run-ahead-and-race-conditions).
+```
+
+Let us now look at the outline of the kernel body:
+
+```python
+def all_gather_lhs_matmul(...):
+ def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
+ wg_idx = lax.axis_index("wg")
+ dev_id = lax.axis_index(axis_name)
+ # This device sends to dev_id - 1, forming a ring.
+ send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
+ send_scratch_ref = plgpu.remote_ref(
+ scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL
+ )
+
+ def device_step(lhs_source_ref, device_offset):
+ # Invariant: lhs_source_ref contains A_{(dev_id + device_offset) % D}
+ # and is ready to be used for computation.
+
+ ...
+
+ # We peel the first step to read data directly from lhs_local_ref.
+ device_step(lhs_local_ref, 0)
+ @pl.loop(1, num_devices)
+ def _device_loop(device_offset):
+ device_step(scratch_ref.at[device_offset - 1], device_offset)
+```
+
+We locate our position in the ring by querying `lax.axis_index(axis_name)` and
+compute the index of the next device in the ring, to which we will be sending
+the data (`send_dev_id`). Then, we invoke `device_step` as many times as there
+are devices. We peel the first step of the loop, because only in that step do
+we use the local reference as the source for the send (after that, the sends
+originate from the data previously received into the scratch buffer).
+
+We are ready to investigate the main loop now:
+
+```python
+def all_gather_lhs_matmul(...):
+ ...
+
+ def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
+ ...
+
+ def device_step(lhs_source_ref, device_offset):
+ # We are computing block (dev_id + device_offset) % D of the output.
+ out_device_idx = lax.rem(device_offset + dev_id, axis_size)
+ out_device_m_slice = pl.ds(out_device_idx * m_shard, m_shard)
+
+ # In step `device_offset`, we send A_{(dev_id + device_offset) % D} to
+ # the next device in the ring, into scratch slot `device_offset`.
+      # We don't send on the last step, since that would only return the data
+      # back to its original source.
+ next_scratch_slot = device_offset
+ is_send_wg = wg_idx == 0 # Only one warpgroup per CTA sends
+ has_send_space = next_scratch_slot < axis_size - 1
+ should_send = is_send_wg & has_send_space
+
+ # This function will be called by hopper_matmul_mgpu.kernel in the body
+ # of its pipeline. We use it to take the tile of LHS loaded into SMEM and
+ # issue a TMA send to the next device in the ring.
+ def send_lhs(m_idx, n_idx, k_idx, a_smem, b_smem, send_ref, should_send):
+ del b_smem # Unused.
+ # We only send when n_idx == 0 to avoid sending the same data
+ # multiple times when revisiting the left operand.
+ @pl.when(should_send & jnp.bool(n_idx == 0))
+ def _():
+ k_slice = pl.ds(k_idx * tile_k, tile_k)
+ m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
+ plgpu.copy_smem_to_gmem(a_smem, send_ref.at[m_slice, k_slice])
+ # Wait for previous copies to complete. We pass in delay_release=1
+ # to the pipeline in the matmul kernel to ensure that it doesn't
+ # overwrite the input until at least the next step completes, but it
+ # will not wait any longer.
+ plgpu.wait_smem_to_gmem(1, wait_read_only=True)
+
+ hopper_matmul_mgpu.kernel(
+ lhs_source_ref, # LHS shard for this step
+ rhs_ref, # RHS shard is always the same
+ out_ref.at[out_device_m_slice], # Slice of output to update
+ out_smem,
+ config=config,
+ pipeline_callback=functools.partial(
+ send_lhs,
+ send_ref=send_scratch_ref.at[next_scratch_slot],
+ should_send=should_send,
+ ),
+ delay_release=1,
+ )
+
+ # Wait for the next scratch to arrive for the next step's computation.
+ # Each device signals its neighbor when it has finished sending.
+ @pl.when(should_send)
+ def _signal():
+ # Make sure our remote copy is done, then signal.
+ plgpu.wait_smem_to_gmem(0, wait_read_only=False)
+ pl.semaphore_signal(received_sem, device_id=send_dev_id)
+ @pl.when(has_send_space)
+ def _wait():
+ # Here, we wait for the data to arrive from the previous device in the
+      # ring. At each step, we expect to receive a signal from each SM.
+ # We use decrement=False to make this operation slightly faster, but
+ # this also means that we need to scale the expected number of signals
+ # by the number of steps taken so far (as the value only increases).
+ pl.semaphore_wait(received_sem, value=(device_offset + 1) * num_sms, decrement=False)
+
+ ...
+```
+
+A few things happen here in sequence:
+1. We begin by computing the slice of the output that this step of the loop
+   will write to.
+2. Then, we call into the optimized matmul kernel, injecting a
+   `pipeline_callback` into it. We use the callback to take advantage of the
+   fact that the compute kernel has to fetch the left operand into SMEM anyway,
+   and we instruct the TMA engine to asynchronously stream that local data to
+   the next device. The traffic is transparently routed through NVLINK by the
+   hardware. It is worth noting that we only issue sends from one of the
+   compute threads, and only when we visit the left operand for the first time
+   (it might be reloaded many times to compute many output tiles).
+3. Finally, the sending thread makes sure that the sends have completed and
+   signals the `received_sem` on the receiving device to indicate that. After
+   that, all threads wait until they are sure that all the data for the next
+   step of the loop has been received (the wait is skipped on the last step).
+   A small sketch of the signal accounting follows below.
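+
+Because the wait uses `decrement=False`, the semaphore value never goes down, so
+the expected value has to account for all signals received so far. Here is a
+small illustration of that accounting (the SM and device counts are arbitrary
+example values):
+
+```python
+num_sms = 4      # example value; the real kernel queries it from the device
+num_devices = 8  # example value
+for device_offset in range(num_devices - 1):  # no wait happens on the last step
+  # Signals accumulated over all steps so far (the value only ever increases).
+  expected = (device_offset + 1) * num_sms
+  print(f"step {device_offset}: wait until received_sem >= {expected}")
+```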
+
+## Integrating the kernel with JAX
+
+To invoke the kernel, you need to wrap it in `jax.shard_map`:
+```python
+m_shard, n_shard, k = 1024, 1024, 1024
+dtype = jnp.float16
+mesh = jax.make_mesh((jax.device_count(),), ("x",),
+ axis_types=(jax.sharding.AxisType.Explicit,))
+with jax.set_mesh(mesh):
+ a = jax.random.normal(jax.random.key(1), (m_shard * jax.device_count(), k), dtype)
+ b = jax.random.normal(jax.random.key(2), (k, n_shard * jax.device_count()), dtype)
+ a = jax.sharding.reshard(a, P("x", None))
+ b = jax.sharding.reshard(b, P(None, "x"))
+
+ # Example config for 8xH100. You might need to retune to your shape.
+ config = hopper_matmul_mgpu.TuningConfig(
+ tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4,
+ grid_minor_dim=MatmulDimension.N, grid_tile_width=8,
+ wg_dimension=MatmulDimension.N,
+ )
+
+ kernel = jax.jit(
+ jax.shard_map(
+ functools.partial(all_gather_lhs_matmul, axis_name="x", config=config),
+ out_specs=P(None, "x"),
+ check_vma=False,
+ )
+ )
+ c = kernel(a, b)
+```
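+
+As a sanity check, we can recompute the same product with plain XLA and compare
+it to the kernel output. This is only a sketch: `out_sharding` is used to
+request the same output layout that the kernel produces, and the expectation on
+the error is informal (the outputs are low-precision):
+
+```python
+with jax.set_mesh(mesh):
+  # XLA will introduce its own (non-overlapped) all-gather of `a` here.
+  reference = jax.jit(
+      lambda a, b: jnp.einsum("mk,kn->mn", a, b, out_sharding=P(None, "x"))
+  )(a, b)
+  rel_err = jax.jit(
+      lambda x, y: jnp.max(jnp.abs(x - y)) / jnp.max(jnp.abs(y))
+  )(c.astype(jnp.float32), reference.astype(jnp.float32))
+print(rel_err)  # Expect a small value.
+```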
\ No newline at end of file
diff --git a/docs/pallas/gpu/index.rst b/docs/pallas/gpu/index.rst
index 840c8428dffe..1dc3b2e3373e 100644
--- a/docs/pallas/gpu/index.rst
+++ b/docs/pallas/gpu/index.rst
@@ -9,6 +9,7 @@ Backend specific documentation for the Mosaic GPU backend.
reference
pipelining
blackwell_matmul
+ collective_matmul
.. toctree::
:caption: Guides