Commit d001dbe

MoE TensorParallel with Eager (#2582)
1 parent f4c47c9 commit d001dbe

File tree: 3 files changed (+315, -148 lines changed)

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
from functools import partial

import torch
from torch.distributed.tensor.placement_types import Placement, Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    ParallelStyle,
)
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
)
from torch.distributed.device_mesh import init_device_mesh
import torch.nn as nn

import thunder.tests.llama4_moe as llama4_moe
from thunder.tests.distributed.helper import DistributedParallelTestCase


# Referred from torchtitan: https://github.com/pytorch/torchtitan/blob/827255bb/torchtitan/experiments/llama4/infra/expert_parallel.py#L25
class GroupedLinearColwiseParallel(ParallelStyle):
    def __init__(
        self,
        *,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(mod, inputs, device_mesh):
        prepared_inputs = []
        INPUT_LAYOUTS = (Replicate(), Replicate())
        assert len(INPUT_LAYOUTS) == len(inputs), "input_layouts and inputs have different lengths"
        # annotate module input placements/sharding with input_layouts
        for inp, input_layout in zip(inputs, INPUT_LAYOUTS):
            assert isinstance(inp, (torch.Tensor, list)), f"inp is not a torch.Tensor or list: {type(inp)}"
            if isinstance(inp, torch.Tensor):
                assert not isinstance(inp, DTensor), "inp is already a DTensor"
                inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
            prepared_inputs.append(inp)
        return tuple(prepared_inputs)

    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "weight", nn.Parameter(distribute_tensor(module.weight, device_mesh, [Shard(2)]))
        )  # Column-wise sharding

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        OUTPUT_LAYOUT = Shard(1)
        if outputs.placements != (OUTPUT_LAYOUT,):
            outputs = outputs.redistribute(placements=(OUTPUT_LAYOUT,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            self._prepare_input_fn,
            partial(self._prepare_output_fn, self.use_local_output),
        )


class GroupedLinearRowwiseParallel(ParallelStyle):
    def __init__(
        self,
        *,
        input_layouts: tuple[Placement | None] | None = None,
        output_layouts: Placement | None = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = input_layouts or (Shard(-1), Replicate())
        self.output_layout = output_layouts or Replicate()
        self.desired_input_layouts = (Shard(-1), Replicate())
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        prepared_inputs = []
        # annotate module input placements/sharding with input_layouts
        for inp, input_layout, desired_input_layout in zip(inputs, input_layouts, desired_input_layouts):
            if isinstance(inp, torch.Tensor):
                if not isinstance(inp, DTensor):
                    inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
                if input_layout != desired_input_layout:
                    inp = inp.redistribute(placements=(desired_input_layout,), async_op=True)
            prepared_inputs.append(inp)
        return tuple(prepared_inputs)

    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter("weight", nn.Parameter(distribute_tensor(module.weight, device_mesh, [Shard(1)])))

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )


def parallelize_moe_model(model: llama4_moe.Llama4MoE, device_mesh: torch.distributed.DeviceMesh):
    """Apply TensorParallel to the MoE model"""

    # Define the parallelization plan as a dictionary
    parallelize_plan = {
        # Shared experts - SwiGLU components
        "shared_experts.gate_proj": ColwiseParallel(use_local_output=False, output_layouts=Shard(2)),
        "shared_experts.up_proj": ColwiseParallel(use_local_output=False, output_layouts=Shard(2)),
        "shared_experts.down_proj": RowwiseParallel(),
        # Routed experts
        "routed_experts.gate_proj": GroupedLinearColwiseParallel(use_local_output=False),
        "routed_experts.up_proj": GroupedLinearColwiseParallel(use_local_output=False),
        "routed_experts.down_proj": GroupedLinearRowwiseParallel(),
    }

    # Parallelize the model
    parallelized_model = parallelize_module(
        model,
        device_mesh,
        parallelize_plan,
    )
    return parallelized_model


class TestLlama4MoEDistributed(DistributedParallelTestCase):
    def test_llama4_moe_distributed(self):
        # Get world size
        world_size = self.world_size
        device = f"cuda:{self.rank}"

        # Initialize device mesh for TensorParallel
        device_mesh = init_device_mesh("cuda", (world_size,))

        config = llama4_moe.Config(
            name="small", hidden_size=256, intermediate_size=512, num_routed_experts=8, num_shared_experts=1
        )

        # Create model with distributed tensors
        model = llama4_moe.Llama4MoE(config)

        # Apply TensorParallel
        parallelized_model = parallelize_moe_model(model, device_mesh)

        # Without this, `thunderfx` falls back to `inductor` for `_grouped_mm`
        # as it doesn't have a grad-rule for the same.
        parallelized_model.requires_grad_(False)

        batch_size, seq_len = 1, 2048
        inp = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16, device=device)

        # Run forward pass
        actual = parallelized_model(inp)
        expected = model(inp)

        torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
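
For context, here is a minimal single-process sketch (not part of the commit) of why the plan above composes: gate_proj/up_proj weights are split along their output-feature dimension (column-wise, Shard(2) for the 3-D grouped weight), and down_proj along its input-feature dimension (row-wise, Shard(1)), so each rank works on its own slice of the intermediate activation and only the final output needs a sum across ranks. The 2-way split, tensor sizes, and helper names below are illustrative assumptions.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2                                        # pretend TP degree; illustrative only
groups, hidden, intermediate, tokens = 4, 8, 16, 12   # tiny, made-up sizes
tokens_per_group = [3, 3, 3, 3]                       # tokens already sorted/grouped per expert

x = torch.randn(tokens, hidden)
w_gate = torch.randn(groups, hidden, intermediate)    # GroupedLinear weight layout: [groups, in, out]
w_up = torch.randn(groups, hidden, intermediate)
w_down = torch.randn(groups, intermediate, hidden)

def gmm(a, w, sizes):
    # eager grouped matmul, mirroring grouped_mm in thunder/tests/llama4_moe.py
    return torch.cat([chunk @ w[i] for i, chunk in enumerate(a.split(sizes))])

# Reference: unsharded grouped SwiGLU
ref = gmm(F.silu(gmm(x, w_gate, tokens_per_group)) * gmm(x, w_up, tokens_per_group), w_down, tokens_per_group)

# Tensor parallel by hand: gate/up sharded on dim 2 (columns), down sharded on dim 1 (rows)
partials = []
for rank in range(world_size):
    g_r = w_gate.chunk(world_size, dim=2)[rank]
    u_r = w_up.chunk(world_size, dim=2)[rank]
    d_r = w_down.chunk(world_size, dim=1)[rank]
    h_r = F.silu(gmm(x, g_r, tokens_per_group)) * gmm(x, u_r, tokens_per_group)  # local slice of the activation
    partials.append(gmm(h_r, d_r, tokens_per_group))                             # partial output per rank

torch.testing.assert_close(sum(partials), ref)        # the sum (an all-reduce in the real setup) matches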

thunder/tests/llama4_moe.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import math
from dataclasses import dataclass
from looseversion import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass(frozen=True)
class Config:
    name: str
    hidden_size: int
    intermediate_size: int
    num_routed_experts: int
    num_shared_experts: int
    dtype: torch.dtype = torch.bfloat16
    device: str = "cuda"


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype, device: str):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))


def _group_sizes_from_offsets(offsets: torch.Tensor) -> list[int]:
    group_sizes = []
    prev = 0
    for offset in offsets:
        group_sizes.append(offset - prev)
        prev = offset
    return group_sizes


if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"):
    # Required otherwise, there is a graph-break.
    _grouped_mm = torch.compiler.allow_in_graph(torch._grouped_mm)


# This function should be replaced with torch._grouped_mm. However,
# torch._grouped_mm is yet to be usable because it requires offsets being
# multiples of 16.
def grouped_mm(a: torch.Tensor, b: torch.Tensor, tokens_per_expert_or_offsets: torch.Tensor) -> torch.Tensor:
    if torch.compiler.is_compiling():
        offsets = tokens_per_expert_or_offsets  # [n]
        return _grouped_mm(a, b, offsets)

    group_outs = []
    tokens_per_expert = tokens_per_expert_or_offsets
    for idx, group_a in enumerate(a.split(tokens_per_expert)):
        group_outs.append(group_a @ b[idx])
    return torch.cat(group_outs)


class GroupedLinear(nn.Module):
    def __init__(self, groups: int, in_features: int, out_features: int, dtype: torch.dtype, device: str):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(groups, in_features, out_features, dtype=dtype, device=device))
        # Initialize the weight in the same way as nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states: torch.Tensor, tokens_per_expert_or_offsets: torch.Tensor) -> torch.Tensor:
        return grouped_mm(hidden_states, self.weight, tokens_per_expert_or_offsets)


class GroupedSwiGLU(nn.Module):
    def __init__(self, groups: int, hidden_size: int, intermediate_size: int, dtype: torch.dtype, device: str):
        super().__init__()
        self.gate_proj = GroupedLinear(groups, hidden_size, intermediate_size, dtype, device)
        self.up_proj = GroupedLinear(groups, hidden_size, intermediate_size, dtype, device)
        self.down_proj = GroupedLinear(groups, intermediate_size, hidden_size, dtype, device)

    def forward(self, hidden_states: torch.Tensor, tokens_per_expert_or_offsets: torch.Tensor) -> torch.Tensor:
        return self.down_proj(
            F.silu(self.gate_proj(hidden_states, tokens_per_expert_or_offsets))
            * self.up_proj(hidden_states, tokens_per_expert_or_offsets),
            tokens_per_expert_or_offsets,
        )


class Llama4MoE(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(
            config.hidden_size, config.num_routed_experts, bias=False, dtype=config.dtype, device=config.device
        )
        self.shared_experts = SwiGLU(
            config.hidden_size, config.intermediate_size * config.num_shared_experts, config.dtype, config.device
        )
        self.routed_experts = GroupedSwiGLU(
            config.num_routed_experts, config.hidden_size, config.intermediate_size, config.dtype, config.device
        )

    def run_routed_experts(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))  # [s, h]

        router_logits = self.gate(hidden_states)  # [s, n]
        topk_weight, topk_ids = router_logits.topk(1)  # [s, 1]
        router_scores = topk_weight.sigmoid()  # [s, 1]
        hidden_states = hidden_states * router_scores  # [s, h]

        counts = torch.zeros(
            topk_ids.size(0),
            self.config.num_routed_experts,
            device=topk_ids.device,
            dtype=torch.int32,
        )  # [s, n]

        counts = counts.scatter(1, topk_ids, 1)  # [s, n]
        tokens_per_expert = counts.sum(0)  # [n]
        token_ids_sorted_by_expert_id = topk_ids.view(-1).argsort()  # [s]
        tokens_sorted_by_expert_id = hidden_states[token_ids_sorted_by_expert_id]  # [s, h]

        if not torch.compiler.is_compiling():
            tokens_per_expert_or_offsets = tokens_per_expert.tolist()
        else:
            tokens_per_expert_or_offsets = torch.cumsum(tokens_per_expert, 0, dtype=torch.int32)  # [n]

        outs_sorted_by_expert_id = self.routed_experts(
            tokens_sorted_by_expert_id, tokens_per_expert_or_offsets
        )  # [s, h]

        token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id)
        outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id]

        return outs_sorted_by_token_id

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.shared_experts(hidden_states) + self.run_routed_experts(hidden_states)
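
A quick way to exercise the eager routing path of this module is a tiny CPU configuration; since nothing is compiled, grouped_mm takes its eager branch and torch._grouped_mm is never called. The sizes, float32 dtype, and cpu device below are illustrative overrides of the Config defaults, not values used anywhere in this commit.

import torch
from thunder.tests.llama4_moe import Config, Llama4MoE

config = Config(
    name="tiny",
    hidden_size=8,
    intermediate_size=16,
    num_routed_experts=4,
    num_shared_experts=1,
    dtype=torch.float32,  # Config defaults to bfloat16/cuda; overridden here for a CPU run
    device="cpu",
)
model = Llama4MoE(config)

inp = torch.randn(1, 6, config.hidden_size, dtype=config.dtype)
with torch.no_grad():
    out = model(inp)
print(out.shape)  # torch.Size([1, 6, 8]); the flattened routed-expert output broadcasts against the shared-expert output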
