Skip to content

Commit 68c9617

Browse files
Merge branch 'main' into update-grpo-loss-type
2 parents 840c691 + adb2238 commit 68c9617

File tree

5 files changed

+166
-9
lines changed

5 files changed

+166
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ loss.backward()
293293
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
294294
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
295295
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
296+
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |
296297

297298

298299
### Alignment Kernels

docs/Low-Level-APIs.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
1313
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
1414
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
15+
| mHC (Hyper-Connections) | `liger_kernel.transformers.LigerMHC` |
1516

1617

1718
### RMS Norm
@@ -72,6 +73,41 @@ Sparsemax is a sparse alternative to softmax that produces sparse probability di
7273

7374
The implementation achieves significant speed improvements and memory savings compared to standard PyTorch implementations, particularly for large input tensors.
7475

76+
### mHC (Manifold-Constrained Hyper-Connections)
77+
78+
mHC implements fused Triton kernels for Manifold-Constrained Hyper-Connections ([arXiv:2512.24880](https://arxiv.org/abs/2512.24880)). It wraps an arbitrary layer `F: [..., C] -> [..., C]` with multiple residual streams, constraining the residual routing matrix `H_res` onto the Birkhoff polytope (doubly-stochastic matrices) via Sinkhorn-Knopp iterations to stabilize training.
79+
80+
The `LigerMHC` module takes input of shape `[..., HC, C]` where `HC` is the number of residual streams, and performs:
81+
82+
1. **Coefficients** -- Compute data-dependent routing coefficients (`h_pre`, `h_post`, `h_res`) via fused matmul + RMS normalization + Sinkhorn-Knopp iterations.
83+
2. **Pre-aggregate** -- `x_in = sum_i h_pre[i] * x[i]` (shape: `[..., C]`)
84+
3. **Layer** -- `f_out = layer(x_in)` (shape: `[..., C]`)
85+
4. **Post + residual** -- `x_out[o] = sum_i h_res[o,i] * x[i] + h_post[o] * f_out` (shape: `[..., HC, C]`)
86+
87+
Usage:
88+
89+
```python
90+
import torch
91+
import torch.nn as nn
92+
from liger_kernel.transformers import LigerMHC
93+
94+
# Wrap a linear layer with 4 residual streams of dimension 256
95+
layer = nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16)
96+
mhc = LigerMHC(layer, hc=4, c=256, phi_dtype=torch.bfloat16).cuda()
97+
98+
# Input: [batch, seq_len, num_streams, channels] in BF16/FP16
99+
x = torch.randn(2, 128, 4, 256, device="cuda", dtype=torch.bfloat16)
100+
out = mhc(x) # shape: [2, 128, 4, 256]
101+
```
102+
103+
Functional APIs are also available:
104+
105+
- `liger_kernel.transformers.functional.liger_mhc_coeffs` -- Compute routing coefficients
106+
- `liger_kernel.transformers.functional.liger_mhc_pre` -- Pre-aggregation
107+
- `liger_kernel.transformers.functional.liger_mhc_post_res` -- Post-aggregation + residual
108+
- `liger_kernel.transformers.functional.liger_mhc_apply` -- Combined pre + post_res
109+
- `liger_kernel.transformers.functional.liger_mhc_forward` -- Full forward pass (coeffs + pre + layer + post_res)
110+
75111
## Alignment Kernels
76112

77113
| **Kernel** | **API** |

src/liger_kernel/ops/swiglu.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,48 @@ class LigerSiLUMulFunction(torch.autograd.Function):
104104
@staticmethod
105105
@ensure_contiguous
106106
def forward(ctx, a, b):
107-
a, b, c = swiglu_forward(a, b)
108-
ctx.save_for_backward(a, b)
109-
return c
107+
if isinstance(a, torch.distributed.tensor.DTensor) or isinstance(b, torch.distributed.tensor.DTensor):
108+
device_mesh, placements = (
109+
(a.device_mesh, a.placements)
110+
if isinstance(a, torch.distributed.tensor.DTensor)
111+
else (b.device_mesh, b.placements)
112+
)
113+
114+
# Assume that full tensors are gathered before and identical across
115+
# the associated process groups.
116+
if not isinstance(a, torch.distributed.tensor.DTensor):
117+
a = torch.distributed.tensor.distribute_tensor(a, device_mesh=device_mesh, placements=placements)
118+
if not isinstance(b, torch.distributed.tensor.DTensor):
119+
b = torch.distributed.tensor.distribute_tensor(b, device_mesh=device_mesh, placements=placements)
120+
a_local, b_local, c_local = swiglu_forward(a.to_local(), b.to_local())
121+
ctx.save_for_backward(a_local, b_local)
122+
ctx.dtensor_metadata = (device_mesh, placements)
123+
return torch.distributed.tensor.DTensor.from_local(c_local, device_mesh, placements)
124+
else:
125+
a, b, c = swiglu_forward(a, b)
126+
ctx.save_for_backward(a, b)
127+
ctx.dtensor_metadata = None
128+
return c
110129

111130
@staticmethod
112131
@ensure_contiguous
113132
def backward(ctx, dc):
114133
a, b = ctx.saved_tensors
134+
if ctx.dtensor_metadata is not None:
135+
device_mesh, placements = ctx.dtensor_metadata
136+
137+
# Assume that full tensors are gathered before and identical across
138+
# the associated process groups.
139+
dc_local = (
140+
dc.to_local()
141+
if isinstance(dc, torch.distributed.tensor.DTensor)
142+
else torch.distributed.tensor.distribute_tensor(dc, device_mesh=device_mesh, placements=placements)
143+
)
144+
a_local, b_local = swiglu_backward(a, b, dc_local)
145+
return (
146+
torch.distributed.tensor.DTensor.from_local(a_local, device_mesh, placements),
147+
torch.distributed.tensor.DTensor.from_local(b_local, device_mesh, placements),
148+
)
149+
115150
a, b = swiglu_backward(a, b, dc)
116151
return a, b

test/convergence/fp32/test_mini_models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,12 +1618,12 @@ def run_mini_model(
16181618
not LLAMA4_AVAILABLE,
16191619
reason="Llama4 not available in this version of trasnformers",
16201620
),
1621-
pytest.mark.xfail(
1622-
reason=(
1623-
"RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:"
1624-
" float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead."
1625-
)
1626-
),
1621+
# pytest.mark.xfail(
1622+
# reason=(
1623+
# "RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:"
1624+
# " float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead."
1625+
# )
1626+
# ),
16271627
],
16281628
),
16291629
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 5e-3, 1e-5, 5e-3, 1e-5),

test/transformers/test_swiglu.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import tempfile
2+
13
import pytest
24
import torch
5+
import torch.multiprocessing as mp
36
import transformers
47

58
from packaging import version
@@ -16,6 +19,7 @@
1619
from liger_kernel.transformers.swiglu import LigerExperts
1720
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
1821
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
22+
from liger_kernel.utils import infer_comm_backend
1923
from liger_kernel.utils import infer_device
2024

2125
IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")
@@ -405,3 +409,84 @@ def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
405409
# Check if gradients are close for x
406410
assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
407411
assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)
412+
413+
414+
def _test_dtensor_liger_silumul(rank, world_size, bsz, seq_len, hidden_size, dtype, atol, rtol, file_name):
415+
torch.distributed.init_process_group(
416+
backend=infer_comm_backend(),
417+
init_method=f"file://{file_name}",
418+
rank=rank,
419+
world_size=world_size,
420+
)
421+
device = f"{infer_device()}:{rank}" if infer_device() != "cpu" else "cpu"
422+
device_mesh = torch.distributed.device_mesh.init_device_mesh(
423+
infer_device(), mesh_shape=(world_size,), mesh_dim_names=("tp",)
424+
)
425+
426+
_a = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
427+
_b = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
428+
429+
# Broadcast from rank 0 so all ranks operate on identical tensors
430+
torch.distributed.broadcast(_a, src=0)
431+
torch.distributed.broadcast(_b, src=0)
432+
433+
assert hidden_size % world_size == 0, f"hidden_size ({hidden_size}) must be divisible by world_size ({world_size})"
434+
435+
# DTensor path: shard inputs along the hidden dim
436+
a1 = _a.clone().detach().requires_grad_(True)
437+
b1 = _b.clone().detach().requires_grad_(True)
438+
da = torch.distributed.tensor.distribute_tensor(
439+
a1, device_mesh=device_mesh, placements=[torch.distributed.tensor.Shard(2)]
440+
)
441+
db = torch.distributed.tensor.distribute_tensor(
442+
b1, device_mesh=device_mesh, placements=[torch.distributed.tensor.Shard(2)]
443+
)
444+
445+
# Regular tensor path
446+
a2 = _a.clone().detach().requires_grad_(True)
447+
b2 = _b.clone().detach().requires_grad_(True)
448+
449+
c1 = LigerSiLUMulFunction.apply(da, db)
450+
c2 = LigerSiLUMulFunction.apply(a2, b2)
451+
452+
torch.testing.assert_close(c1.full_tensor(), c2, atol=atol, rtol=rtol)
453+
454+
grad = torch.randn_like(c2)
455+
torch.distributed.broadcast(grad, src=0)
456+
dgrad = torch.distributed.tensor.distribute_tensor(
457+
grad, device_mesh=device_mesh, placements=[torch.distributed.tensor.Shard(2)]
458+
)
459+
460+
c1.backward(dgrad)
461+
c2.backward(grad)
462+
463+
torch.testing.assert_close(da.grad.full_tensor(), a2.grad, atol=atol, rtol=rtol)
464+
torch.testing.assert_close(db.grad.full_tensor(), b2.grad, atol=atol, rtol=rtol)
465+
466+
467+
@pytest.mark.xfail(
468+
torch.cuda.device_count() < 8,
469+
reason="Pending multi-GPU host support. This test is expected to pass when run with multi-GPU host.",
470+
)
471+
@pytest.mark.parametrize(
472+
"world_size, bsz, seq_len, hidden_size",
473+
[
474+
(4, 2, 2, 8),
475+
(8, 9, 7, 64),
476+
],
477+
)
478+
@pytest.mark.parametrize(
479+
"dtype, atol, rtol",
480+
[
481+
(torch.float32, 1e-4, 1e-6),
482+
(torch.bfloat16, 2e-1, 2e-2),
483+
],
484+
)
485+
def test_dtensor_liger_silumul(world_size, bsz, seq_len, hidden_size, dtype, atol, rtol):
486+
with tempfile.NamedTemporaryFile() as f:
487+
mp.spawn(
488+
_test_dtensor_liger_silumul,
489+
args=(world_size, bsz, seq_len, hidden_size, dtype, atol, rtol, f.name),
490+
nprocs=world_size,
491+
join=True,
492+
)

0 commit comments

Comments
 (0)