
Commit 635db73

transposed w2 to have reduction dim be innermost dim
1 parent f68e81e commit 635db73

5 files changed (+22, -19 lines)

mixtral-moe/generate.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def device_sync(device):
 
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
 
 # support running without installing as a package
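The change here only comments the experimental cache flag out rather than removing it. If the FX graph cache is still wanted, a guarded opt-in (a sketch, not part of this commit) keeps older PyTorch builds that lack the flag working:

# Sketch only, not part of the commit: re-enable the experimental FX graph cache
# when the installed PyTorch build exposes the flag the diff comments out.
import torch._inductor.config

if hasattr(torch._inductor.config, "fx_graph_cache"):
    torch._inductor.config.fx_graph_cache = True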

mixtral-moe/model.py

Lines changed: 7 additions & 6 deletions
@@ -183,21 +183,22 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         y = self.wo(y)
         return y
 
+import torch.distributed
 
 class ConditionalFeedForward(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
 
     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-        w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D]
-        w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D]
+        w1_weights = self.w1[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3[expert_indices] # [T, A, D, D]
         w2_weights = self.w2[expert_indices] # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
         return expert_outs
 
 
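The einsum edit is a relabelling rather than a math change: flipping the subscripts to 'taoi'/'taio' replaces the explicit .transpose(-1, -2) views, and with w2 stored as [num_experts, dim, intermediate_size] the reduction index of the final contraction now lands on the innermost, contiguous dimension of the gathered weights. A minimal sketch (shapes invented, not the Mixtral config) checking that the new w2 layout reproduces the old result:

import torch

E, T, A, D, I = 8, 4, 2, 16, 32                  # experts, tokens, active experts, dim, intermediate
w2_old = torch.randn(E, I, D)                    # old layout: reduction dim I is the outer dim
w2_new = w2_old.transpose(-1, -2).contiguous()   # new layout [E, D, I]: reduction dim I innermost

expert_indices = torch.randint(E, (T, A))
h = torch.randn(T, A, I)                         # stands in for (x1 * x3)

out_old = torch.einsum('tao, taoi -> tai', h, w2_old[expert_indices])
out_new = torch.einsum('tao, taio -> tai', h, w2_new[expert_indices])
print(torch.allclose(out_old, out_new, atol=1e-4))  # True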

mixtral-moe/quantize.py

Lines changed: 9 additions & 9 deletions
@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
                 cur_state_dict[f"{fqn}.weight"] = int8_weight
                 cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
             elif isinstance(mod, ConditionalFeedForward):
-                num_experts, intermediate_size, dim = mod.w1.shape
                 for weight_idx in range(0, 3):
                     weight_name = f"w{weight_idx + 1}"
                     scales_name = f"scales{weight_idx + 1}"
                     weight = getattr(mod, weight_name)
+                    num_experts, intermediate_size, dim = weight.shape
 
                     bit8_weight_list = []
                     scales_list = []
@@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype):
         self.target_dtype = target_dtype
 
         self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
-        self.register_buffer("w2", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
+        self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
         self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
 
         self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
-        self.register_buffer("scales2", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
+        self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
         self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
 
     def forward(self, x, expert_indices):
-        w1_weights = (self.w1.to(x.dtype)[expert_indices] * self.scales1[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D]
-        w3_weights = (self.w3.to(x.dtype)[expert_indices] * self.scales3[expert_indices].to(x.dtype).unsqueeze(-1)).transpose(-1, -2) # [T, A, D, D]
-        w2_weights = (self.w2.to(x.dtype)[expert_indices] * self.scales2[expert_indices].to(x.dtype).unsqueeze(-1)) # [T, A, D, D]
-        x1 = F.silu(torch.einsum('ti,taio -> tao', x, w1_weights))
-        x3 = torch.einsum('ti, taio -> tao', x, w3_weights)
-        expert_outs = torch.einsum('tao, taoi -> tai', (x1 * x3), w2_weights)
+        w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D]
+        w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D]
+        w2_weights = self.w2.to(x.dtype)[expert_indices]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D]
         return expert_outs
 
 
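In the int8 path the per-channel scales now multiply the einsum outputs instead of the gathered weights. scales1/scales3 scale the 'o' index and scales2 the 'i' index, i.e. the output channel of each contraction, so the factor commutes with the reduction and no scaled [T, A, ...] weight copy has to be materialized. A small sketch with invented shapes showing the two orderings agree for the w1 path:

import torch

T, A, D, I, E = 3, 2, 8, 16, 4
x = torch.randn(T, D)
w1 = torch.randint(-128, 128, (E, I, D), dtype=torch.int8)
scales1 = torch.rand(E, I)
expert_indices = torch.randint(E, (T, A))

w = w1.to(x.dtype)[expert_indices]          # [T, A, I, D]
s = scales1[expert_indices].to(x.dtype)     # [T, A, I]

old = torch.einsum('ti,taoi -> tao', x, w * s.unsqueeze(-1))  # scale the weights first
new = torch.einsum('ti,taoi -> tao', x, w) * s                # scale the output instead
print(torch.allclose(old, new, atol=1e-2))                    # True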

mixtral-moe/scripts/convert_hf_checkpoint.py

Lines changed: 4 additions & 2 deletions
@@ -76,9 +76,11 @@ def convert_hf_checkpoint(
             del final_result[key]
             del final_result[key.replace("wq", "wk")]
             del final_result[key.replace("wq", "wv")]
-        if "w1" in key or "w2" in key or "w3" in key:
+        elif "w1" in key or "w3" in key:
             final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
-        if "gate" in key:
+        elif "w2" in key:
+            final_result[key] = final_result[key].reshape(config.num_experts, config.dim, config.intermediate_size).contiguous()
+        elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
 
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")

mixtral-moe/tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def shard_qkv(qkv, dim, weight_splits):
9999
def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
100100
mlp.cond_ffn.w1 = nn.Parameter(shard(mlp.cond_ffn.w1, 1), requires_grad=False)
101101
mlp.cond_ffn.w3 = nn.Parameter(shard(mlp.cond_ffn.w3, 1), requires_grad=False)
102-
mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 1), requires_grad=False)
102+
mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False)
103103

104104
if hasattr(mlp.cond_ffn, "scales1"):
105105
mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
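With w2 now laid out as [num_experts, dim, intermediate_size], the tensor-parallel split over the intermediate dimension moves from dim 1 to dim 2, so every rank holds matching slices of w1/w3 (their dim 1) and w2 (its dim 2). A toy sketch with a stand-in shard helper (not the repo's implementation):

import torch

def shard(x: torch.Tensor, dim: int, rank: int = 0, world_size: int = 2) -> torch.Tensor:
    # simplified stand-in: return this rank's slice of x along `dim`
    assert x.size(dim) % world_size == 0
    return x.chunk(world_size, dim=dim)[rank]

E, D, I = 8, 16, 32
w1 = torch.randn(E, I, D)   # shard the intermediate dim at index 1
w2 = torch.randn(E, D, I)   # shard the intermediate dim at index 2
print(shard(w1, 1).shape, shard(w2, 2).shape)  # both torch.Size([8, 16, 16])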
