| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |
### RMS Norm
### Sparsemax

Sparsemax is a sparse alternative to softmax that produces sparse probability distributions.

The implementation achieves significant speed improvements and memory savings compared to standard PyTorch implementations, particularly for large input tensors.

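For reference, sparsemax (Martins & Astudillo, 2016) is the Euclidean projection of the logits onto the probability simplex. A minimal plain-PyTorch sketch of the operation (an illustration only, not the fused Triton kernel) looks like:

```python
import torch

def sparsemax_reference(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Plain-PyTorch sparsemax: project logits onto the probability simplex.

    Unlike softmax, entries below the data-dependent threshold `tau`
    are set to exactly zero, yielding a sparse distribution.
    """
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    k = torch.arange(1, z.size(dim) + 1, device=z.device, dtype=z.dtype)
    shape = [1] * z.dim()
    shape[dim] = -1
    k = k.view(shape)
    cssv = z_sorted.cumsum(dim) - 1                    # cumulative sums minus 1
    support = (k * z_sorted > cssv).to(z.dtype)        # entries in the support
    k_max = support.sum(dim=dim, keepdim=True)         # support size
    tau = cssv.gather(dim, k_max.long() - 1) / k_max   # threshold
    return torch.clamp(z - tau, min=0)
```

The output sums to 1 along `dim` but, unlike softmax, contains exact zeros for low-scoring entries.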
### mHC (Manifold-Constrained Hyper-Connections)

mHC implements fused Triton kernels for Manifold-Constrained Hyper-Connections ([arXiv:2512.24880](https://arxiv.org/abs/2512.24880)). It wraps an arbitrary layer `F: [..., C] -> [..., C]` with multiple residual streams, constraining the residual routing matrix `H_res` onto the Birkhoff polytope (the set of doubly-stochastic matrices) via Sinkhorn-Knopp iterations to stabilize training.

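The Birkhoff-polytope constraint can be illustrated with a plain-PyTorch Sinkhorn-Knopp sketch (an illustration of the doubly-stochastic projection, not the fused kernel's exact implementation):

```python
import torch

def sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 20) -> torch.Tensor:
    """Map square logits to an (approximately) doubly-stochastic matrix
    by alternately normalizing the rows and columns of exp(logits)."""
    m = torch.exp(logits)  # strictly positive entries
    for _ in range(n_iters):
        m = m / m.sum(dim=-1, keepdim=True)  # rows sum to 1
        m = m / m.sum(dim=-2, keepdim=True)  # columns sum to 1
    return m
```

Because every row and column sums to 1, mixing residual streams with such a matrix neither amplifies nor attenuates the total residual signal.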
The `LigerMHC` module takes input of shape `[..., HC, C]`, where `HC` is the number of residual streams, and performs:

1. **Coefficients** -- compute data-dependent routing coefficients (`h_pre`, `h_post`, `h_res`) via fused matmul + RMS normalization + Sinkhorn-Knopp iterations.
2. **Pre-aggregate** -- `x_in = sum_i h_pre[i] * x[i]` (shape: `[..., C]`).
3. **Layer** -- `f_out = layer(x_in)` (shape: `[..., C]`).
4. **Post + residual** -- `x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out` (shape: `[..., HC, C]`).

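Steps 2-4 can be written as an unfused plain-PyTorch reference (for illustration; the actual kernels fuse these operations):

```python
import torch

def mhc_apply_reference(x, h_pre, h_post, h_res, layer):
    """Unfused reference of the pre-aggregate / layer / post + residual steps.

    Shapes: x: [..., HC, C]; h_pre, h_post: [..., HC]; h_res: [..., HC, HC].
    """
    x_in = torch.einsum("...i,...ic->...c", h_pre, x)       # step 2: pre-aggregate
    f_out = layer(x_in)                                      # step 3: wrapped layer
    x_out = torch.einsum("...oi,...ic->...oc", h_res, x)     # step 4: residual mix
    x_out = x_out + h_post.unsqueeze(-1) * f_out.unsqueeze(-2)  # + post-scaled output
    return x_out
```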
Usage:

```python
import torch
import torch.nn as nn

from liger_kernel.transformers import LigerMHC

# Wrap a linear layer with 4 residual streams of dimension 256
layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16)
mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda()

# Input: [batch, seq_len, num_streams, channels] in BF16/FP16
x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16)
out = mhc(x)  # shape: [2, 128, 4, 256]
```

Functional APIs are also available:

- `liger_kernel.transformers.functional.liger_mhc_coeffs` -- compute routing coefficients
- `liger_kernel.transformers.functional.liger_mhc_pre` -- pre-aggregation
- `liger_kernel.transformers.functional.liger_mhc_post_res` -- post-aggregation + residual
- `liger_kernel.transformers.functional.liger_mhc_apply` -- combined pre + post_res
- `liger_kernel.transformers.functional.liger_mhc_forward` -- full forward pass (coeffs + pre + layer + post_res)

## Alignment Kernels

| **Kernel** | **API** |