Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/liger_kernel/ops/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,48 @@ class LigerSiLUMulFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, a, b):
    """SwiGLU forward: computes silu(a) * b via the fused Triton kernel.

    Supports ``torch.distributed.tensor.DTensor`` inputs: if either input is
    a DTensor, the kernel runs on the local shards and the output is wrapped
    back into a DTensor on the same device mesh / placements.
    """
    dtensor_cls = torch.distributed.tensor.DTensor
    a_is_dtensor = isinstance(a, dtensor_cls)
    b_is_dtensor = isinstance(b, dtensor_cls)

    if not (a_is_dtensor or b_is_dtensor):
        # Plain-tensor fast path: run the kernel directly.
        a, b, c = swiglu_forward(a, b)
        ctx.save_for_backward(a, b)
        ctx.dtensor_metadata = None
        return c

    # Borrow mesh/placements from whichever input is already a DTensor.
    if a_is_dtensor:
        device_mesh, placements = a.device_mesh, a.placements
    else:
        device_mesh, placements = b.device_mesh, b.placements

    # Assume that full tensors are gathered before and identical across
    # the associated process groups.
    if not a_is_dtensor:
        a = torch.distributed.tensor.distribute_tensor(a, device_mesh=device_mesh, placements=placements)
    if not b_is_dtensor:
        b = torch.distributed.tensor.distribute_tensor(b, device_mesh=device_mesh, placements=placements)

    # Run the kernel on the local shards only.
    a_local, b_local, c_local = swiglu_forward(a.to_local(), b.to_local())
    # NOTE(review): only the local shards are saved; backward therefore
    # operates shard-wise and relies on the placements stored below.
    ctx.save_for_backward(a_local, b_local)
    ctx.dtensor_metadata = (device_mesh, placements)
    return dtensor_cls.from_local(c_local, device_mesh, placements)

@staticmethod
@ensure_contiguous
def backward(ctx, dc):
    """SwiGLU backward: returns ``(da, db)`` for upstream gradient ``dc``.

    Mirrors forward's DTensor handling: when forward ran on DTensor inputs
    (``ctx.dtensor_metadata`` is set), gradients are computed on the local
    shards and wrapped back into DTensors on the same mesh / placements.
    """
    a, b = ctx.saved_tensors
    metadata = ctx.dtensor_metadata

    if metadata is None:
        # Plain-tensor path: kernel operates on the saved full tensors.
        da, db = swiglu_backward(a, b, dc)
        return da, db

    device_mesh, placements = metadata

    # Assume that full tensors are gathered before and identical across
    # the associated process groups.
    if isinstance(dc, torch.distributed.tensor.DTensor):
        dc_local = dc.to_local()
    else:
        # NOTE(review): this branch yields a DTensor (no ``.to_local()``
        # call), while ``a``/``b`` are local shards — confirm the kernel
        # accepts that mix or whether ``.to_local()`` is missing here.
        dc_local = torch.distributed.tensor.distribute_tensor(dc, device_mesh=device_mesh, placements=placements)

    da_local, db_local = swiglu_backward(a, b, dc_local)
    # NOTE(review): grads are always returned as DTensors on the DTensor
    # path, even if one forward input was a plain tensor — verify autograd
    # accepts the type mismatch for that input.
    return (
        torch.distributed.tensor.DTensor.from_local(da_local, device_mesh, placements),
        torch.distributed.tensor.DTensor.from_local(db_local, device_mesh, placements),
    )
12 changes: 6 additions & 6 deletions test/convergence/fp32/test_mini_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,12 +1618,12 @@ def run_mini_model(
not LLAMA4_AVAILABLE,
reason="Llama4 not available in this version of trasnformers",
),
pytest.mark.xfail(
reason=(
"RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:"
" float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead."
)
),
# pytest.mark.xfail(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to uncomment it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it was commented originally. make checkstyle added a space to make it pass the formatting checks.

# reason=(
# "RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype:"
# " float key.dtype: c10::BFloat16 and value.dtype: c10::BFloat16 instead."
# )
# ),
],
),
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 5e-3, 1e-5, 5e-3, 1e-5),
Expand Down
85 changes: 85 additions & 0 deletions test/transformers/test_swiglu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import tempfile

import pytest
import torch
import torch.multiprocessing as mp
import transformers

from packaging import version
Expand All @@ -16,6 +19,7 @@
from liger_kernel.transformers.swiglu import LigerExperts
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_comm_backend
from liger_kernel.utils import infer_device

IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__) >= version.parse("5.0.0")
Expand Down Expand Up @@ -405,3 +409,84 @@ def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
# Check if gradients are close for x
assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)


def _test_dtensor_liger_silumul(rank, world_size, bsz, seq_len, hidden_size, dtype, atol, rtol, file_name):
    """Per-rank worker spawned by ``mp.spawn``.

    Checks that ``LigerSiLUMulFunction`` produces identical forward outputs
    and input gradients for DTensor inputs (sharded along the hidden dim)
    and for regular tensors.

    Args:
        rank: this process's rank (injected by ``mp.spawn``).
        world_size: total number of spawned processes / mesh size.
        bsz, seq_len, hidden_size: input tensor dimensions.
        dtype: tensor dtype under test.
        atol, rtol: tolerances for ``torch.testing.assert_close``.
        file_name: path backing the FileStore used for rendezvous.
    """
    torch.distributed.init_process_group(
        backend=infer_comm_backend(),
        init_method=f"file://{file_name}",
        rank=rank,
        world_size=world_size,
    )
    # Fix: tear down the process group on every exit path; the original
    # leaked the FileStore-backed group in each spawned worker.
    try:
        device = f"{infer_device()}:{rank}" if infer_device() != "cpu" else "cpu"
        device_mesh = torch.distributed.device_mesh.init_device_mesh(
            infer_device(), mesh_shape=(world_size,), mesh_dim_names=("tp",)
        )

        assert hidden_size % world_size == 0, (
            f"hidden_size ({hidden_size}) must be divisible by world_size ({world_size})"
        )

        _a = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
        _b = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)

        # Broadcast from rank 0 so all ranks operate on identical tensors
        torch.distributed.broadcast(_a, src=0)
        torch.distributed.broadcast(_b, src=0)

        # DTensor path: shard inputs along the hidden dim
        a1 = _a.clone().detach().requires_grad_(True)
        b1 = _b.clone().detach().requires_grad_(True)
        da = torch.distributed.tensor.distribute_tensor(
            a1, device_mesh=device_mesh, placements=[torch.distributed.tensor.Shard(2)]
        )
        db = torch.distributed.tensor.distribute_tensor(
            b1, device_mesh=device_mesh, placements=[torch.distributed.tensor.Shard(2)]
        )

        # Regular tensor path
        a2 = _a.clone().detach().requires_grad_(True)
        b2 = _b.clone().detach().requires_grad_(True)

        c1 = LigerSiLUMulFunction.apply(da, db)
        c2 = LigerSiLUMulFunction.apply(a2, b2)

        torch.testing.assert_close(c1.full_tensor(), c2, atol=atol, rtol=rtol)

        # Backward with an identical upstream gradient on both paths.
        grad = torch.randn_like(c2)
        torch.distributed.broadcast(grad, src=0)
        dgrad = torch.distributed.tensor.distribute_tensor(
            grad, device_mesh=device_mesh, placements=[torch.distributed.tensor.Shard(2)]
        )

        c1.backward(dgrad)
        c2.backward(grad)

        torch.testing.assert_close(da.grad.full_tensor(), a2.grad, atol=atol, rtol=rtol)
        torch.testing.assert_close(db.grad.full_tensor(), b2.grad, atol=atol, rtol=rtol)
    finally:
        torch.distributed.destroy_process_group()


@pytest.mark.xfail(
    torch.cuda.device_count() < 8,
    reason="Pending multi-GPU host support. This test is expected to pass when run with multi-GPU host.",
)
@pytest.mark.parametrize(
    "world_size, bsz, seq_len, hidden_size",
    [
        (4, 2, 2, 8),
        (8, 9, 7, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        (torch.bfloat16, 2e-1, 2e-2),
    ],
)
def test_dtensor_liger_silumul(world_size, bsz, seq_len, hidden_size, dtype, atol, rtol):
    """Spawn ``world_size`` workers that compare DTensor vs. regular-tensor
    behavior of ``LigerSiLUMulFunction`` (see ``_test_dtensor_liger_silumul``).

    A temporary file provides the FileStore rendezvous point for the
    process group and is removed automatically when the context exits.
    """
    with tempfile.NamedTemporaryFile() as rendezvous_file:
        worker_args = (world_size, bsz, seq_len, hidden_size, dtype, atol, rtol, rendezvous_file.name)
        mp.spawn(
            _test_dtensor_liger_silumul,
            args=worker_args,
            nprocs=world_size,
            join=True,
        )
Loading