1 change: 1 addition & 0 deletions README.md
@@ -293,6 +293,7 @@ loss.backward()
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |


### Alignment Kernels
36 changes: 36 additions & 0 deletions docs/Low-Level-APIs.md
@@ -12,6 +12,7 @@
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |


### RMS Norm
@@ -72,6 +73,41 @@ Sparsemax is a sparse alternative to softmax that produces sparse probability distributions.

The implementation achieves significant speed improvements and memory savings compared to standard PyTorch implementations, particularly for large input tensors.

### mHC (Manifold-Constrained Hyper-Connections)

mHC implements fused Triton kernels for Manifold-Constrained Hyper-Connections ([arXiv:2512.24880](https://arxiv.org/abs/2512.24880)). It wraps an arbitrary layer `F: [..., C] -> [..., C]` with multiple residual streams, constraining the residual routing matrix `H_res` to the Birkhoff polytope (the set of doubly-stochastic matrices) via Sinkhorn-Knopp iterations, which stabilizes training.
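The Sinkhorn-Knopp projection can be illustrated in a few lines of plain PyTorch. This is a minimal, unfused sketch for intuition only, not the Triton kernel used by `LigerMHC`; the function name and iteration count are assumptions:

```python
import torch


def sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 20) -> torch.Tensor:
    """Project scores onto the Birkhoff polytope (doubly-stochastic matrices)
    by exponentiating and then alternately normalizing rows and columns."""
    m = torch.exp(logits)  # strictly positive entries
    for _ in range(n_iters):
        m = m / m.sum(dim=-1, keepdim=True)  # normalize each row to sum to 1
        m = m / m.sum(dim=-2, keepdim=True)  # normalize each column to sum to 1
    return m


h_res = sinkhorn_knopp(torch.randn(4, 4))
# After enough iterations, both row sums and column sums are ~1,
# so mixing residual streams with h_res neither amplifies nor loses signal mass.
```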

The `LigerMHC` module takes input of shape `[..., HC, C]` where `HC` is the number of residual streams, and performs:

1. **Coefficients** -- Compute data-dependent routing coefficients (`h_pre`, `h_post`, `h_res`) via fused matmul + RMS normalization + Sinkhorn-Knopp iterations.
2. **Pre-aggregate** -- `x_in = sum_i h_pre[i] * x[i]` (shape: `[..., C]`)
3. **Layer** -- `f_out = layer(x_in)` (shape: `[..., C]`)
4. **Post + residual** -- `x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out` (shape: `[..., HC, C]`)
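Steps 2-4 can be sketched as an unfused PyTorch reference. Note this is illustrative only: in `LigerMHC` the coefficients are data-dependent (computed per token in step 1) and the operations are fused Triton kernels, whereas here fixed coefficient tensors are passed in and the function name is an assumption:

```python
import torch


def mhc_reference(x, h_pre, h_post, h_res, layer):
    """Unfused reference for steps 2-4 on input x of shape [..., HC, C]."""
    # Step 2: pre-aggregate streams, x_in = sum_i h_pre[i] * x[i]  -> [..., C]
    x_in = torch.einsum("i,...ic->...c", h_pre, x)
    # Step 3: apply the wrapped layer  -> [..., C]
    f_out = layer(x_in)
    # Step 4: mix residual streams with h_res and add the scaled layer output
    # x_out[o] = sum_i h_res[o, i] * x[i] + h_post[o] * f_out  -> [..., HC, C]
    return torch.einsum("oi,...ic->...oc", h_res, x) + h_post[:, None] * f_out[..., None, :]


HC, C = 4, 8
x = torch.randn(2, 3, HC, C)
h_pre, h_post = torch.rand(HC), torch.rand(HC)
h_res = torch.full((HC, HC), 1.0 / HC)  # trivially doubly-stochastic for this demo
out = mhc_reference(x, h_pre, h_post, h_res, torch.nn.Identity())
# out has shape [2, 3, 4, 8]: the HC streams are preserved through the wrapped layer
```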

Usage:

```python
import torch
import torch.nn as nn
from liger_kernel.transformers import LigerMHC

# Wrap a linear layer with 4 residual streams of dimension 256
layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16)
mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda()

# Input: [batch, seq_len, num_streams, channels] in BF16/FP16
x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16)
out = mhc(x) # shape: [2, 128, 4, 256]
```

Functional APIs are also available:

- `liger_kernel.transformers.functional.liger_mhc_coeffs` -- Compute routing coefficients
- `liger_kernel.transformers.functional.liger_mhc_pre` -- Pre-aggregation
- `liger_kernel.transformers.functional.liger_mhc_post_res` -- Post-aggregation + residual
- `liger_kernel.transformers.functional.liger_mhc_apply` -- Combined pre + post_res
- `liger_kernel.transformers.functional.liger_mhc_forward` -- Full forward pass (coeffs + pre + layer + post_res)

## Alignment Kernels

| **Kernel** | **API** |