
Commit 5841280

Add mHC kernel documentation to README and API reference (#1132)
## Summary

- Add `mHC (Hyper-Connections)` entry to the Model Kernels table in `README.md` and `docs/Low-Level-APIs.md`
- Add detailed description section in `docs/Low-Level-APIs.md` with architecture overview, usage example, and functional API reference

## Reference Issue

- Follows up on #1065 which added the mHC fused kernels but did not update documentation
1 parent 73cb70b commit 5841280


2 files changed: +37 additions, 0 deletions


README.md

Lines changed: 1 addition & 0 deletions
@@ -293,6 +293,7 @@ loss.backward()
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |

### Alignment Kernels

docs/Low-Level-APIs.md

Lines changed: 36 additions & 0 deletions
@@ -12,6 +12,7 @@
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |

### RMS Norm
@@ -72,6 +73,41 @@ Sparsemax is a sparse alternative to softmax that produces sparse probability di

The implementation achieves significant speed improvements and memory savings compared to standard PyTorch implementations, particularly for large input tensors.

### mHC (Manifold-Constrained Hyper-Connections)
mHC implements fused Triton kernels for Manifold-Constrained Hyper-Connections ([arXiv:2512.24880](https://arxiv.org/abs/2512.24880)). It wraps an arbitrary layer `F: [..., C] -> [..., C]` with multiple residual streams, constraining the residual routing matrix `H_res` onto the Birkhoff polytope (doubly-stochastic matrices) via Sinkhorn-Knopp iterations to stabilize training.
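
The Sinkhorn-Knopp projection onto the Birkhoff polytope can be illustrated in plain PyTorch. This is a minimal sketch of the idea, not the fused Triton kernel; the function name and iteration count are illustrative:

```python
import torch

torch.manual_seed(0)

def sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 30) -> torch.Tensor:
    """Project a square logit matrix onto (approximately) the Birkhoff
    polytope by alternately normalizing rows and columns."""
    m = torch.exp(logits)  # strictly positive entries
    for _ in range(n_iters):
        m = m / m.sum(dim=-1, keepdim=True)  # rows sum to 1
        m = m / m.sum(dim=-2, keepdim=True)  # columns sum to 1
    return m

h_res = sinkhorn_knopp(torch.randn(4, 4))
# Row and column sums are both ~1, so h_res is (nearly) doubly stochastic.
```

Because every entry stays positive and both row and column sums converge to 1, the resulting routing matrix is a convex combination of permutations, which is what stabilizes the residual mixing.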
The `LigerMHC` module takes input of shape `[..., HC, C]` where `HC` is the number of residual streams, and performs:
1. **Coefficients** -- Compute data-dependent routing coefficients (`h_pre`, `h_post`, `h_res`) via fused matmul + RMS normalization + Sinkhorn-Knopp iterations.
2. **Pre-aggregate** -- `x_in = sum_i h_pre[i] * x[i]` (shape: `[..., C]`)
3. **Layer** -- `f_out = layer(x_in)` (shape: `[..., C]`)
4. **Post + residual** -- `x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out` (shape: `[..., HC, C]`)
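
The aggregation steps above can be sketched in plain eager PyTorch. This is an illustrative reference, not the fused kernels; the per-position coefficient shapes (`[..., HC]` for `h_pre`/`h_post`, `[..., HC, HC]` for `h_res`) and the identity stand-in for the wrapped layer are assumptions here:

```python
import torch

torch.manual_seed(0)
B, S, HC, C = 2, 3, 4, 8  # batch, sequence, residual streams, channels

x = torch.randn(B, S, HC, C)
h_pre = torch.randn(B, S, HC)   # assumed per-position shape [..., HC]
h_post = torch.randn(B, S, HC)
h_res = torch.softmax(torch.randn(B, S, HC, HC), dim=-1)  # stand-in routing matrix

# Step 2: pre-aggregate the streams into a single layer input [..., C]
x_in = torch.einsum('bsic,bsi->bsc', x, h_pre)

# Step 3: apply the wrapped layer F (identity used as a stand-in)
f_out = x_in

# Step 4: mix the residual streams and add the scaled layer output [..., HC, C]
x_out = torch.einsum('bsoi,bsic->bsoc', h_res, x) \
        + h_post.unsqueeze(-1) * f_out.unsqueeze(-2)
# x_out has shape [B, S, HC, C], same as the input
```

The fused kernels compute the same quantities but avoid materializing the intermediate broadcasts.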
Usage:
```python
import torch
import torch.nn as nn
from liger_kernel.transformers import LigerMHC
# Wrap a linear layer with 4 residual streams of dimension 256
layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16)
mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda()
# Input: [batch, seq_len, num_streams, channels] in BF16/FP16
x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16)
out = mhc(x) # shape: [2, 128, 4, 256]
```
Functional APIs are also available:
- `liger_kernel.transformers.functional.liger_mhc_coeffs` -- Compute routing coefficients
- `liger_kernel.transformers.functional.liger_mhc_pre` -- Pre-aggregation
- `liger_kernel.transformers.functional.liger_mhc_post_res` -- Post-aggregation + residual
- `liger_kernel.transformers.functional.liger_mhc_apply` -- Combined pre + post_res
- `liger_kernel.transformers.functional.liger_mhc_forward` -- Full forward pass (coeffs + pre + layer + post_res)
## Alignment Kernels

| **Kernel** | **API** |
