Commit d008d64
feat:[AutoDeploy] Update MoE pattern matcher to drop expert selection logic (#3283)
* update matcher to match expert compute first, then extract other args with LCA (Signed-off-by: Frida Hou <[email protected]>)
* support 3D and 2D input in torch.ops.moe.trtllm_fused_moe (Signed-off-by: Frida Hou <[email protected]>)
* update custom ops to support 3D and 2D inputs (Signed-off-by: Ubuntu <[email protected]>)
* update deepseek patch (Signed-off-by: Ubuntu <[email protected]>)

Signed-off-by: Frida Hou <[email protected]>
1 parent b0ce137 commit d008d64
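A minimal, hypothetical sketch of the LCA idea referenced in the first bullet (the actual matcher change lives in the third changed file, which is not shown on this page): after matching the expert-compute subgraph, the remaining arguments can be recovered by walking back from the matched boundary nodes to their lowest common ancestor in the torch.fx graph. The helpers `ancestors_of` and `lowest_common_ancestor` below are illustrative names, not AutoDeploy APIs.

import torch.fx as fx


def ancestors_of(node: fx.Node) -> set:
    """Collect all transitive producers of `node` (illustrative helper)."""
    seen, stack = set(), [node]
    while stack:
        for inp in stack.pop().all_input_nodes:
            if inp not in seen:
                seen.add(inp)
                stack.append(inp)
    return seen


def lowest_common_ancestor(graph: fx.Graph, nodes: list) -> fx.Node:
    """Return the topologically latest node that feeds every node in `nodes`."""
    common = set.intersection(*(ancestors_of(n) for n in nodes))
    lca = None
    for n in graph.nodes:  # graph.nodes iterates in topological order
        if n in common:
            lca = n
    return lca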

File tree

3 files changed (+269, -188 lines)

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe.py (11 additions & 4 deletions)

@@ -35,7 +35,9 @@ def torch_moe(
         torch.Tensor: Output tensor with the same shape as the input x.
     """
 
-    hidden_dim = x.shape[-1]
+    x_shape = x.shape
+    hidden_dim = x_shape[-1]
+    x = x.view(-1, hidden_dim)
     num_experts = len(w1_weight)
 
     final_hidden_states = torch.zeros_like(x)
@@ -63,7 +65,7 @@ def torch_moe(
             0, top_x, current_hidden_states.to(final_hidden_states.dtype)
         )
 
-    return final_hidden_states.view_as(x)
+    return final_hidden_states.view(x_shape)
 
 
 @torch_moe.register_fake
@@ -104,6 +106,8 @@ def torch_fused_moe(
     Returns:
         torch.Tensor: Output tensor with the same shape as the input x.
     """
+    x_shape = x.shape
+    x = x.view(-1, x_shape[-1])
     num_experts = w2_stacked_weight.shape[0]
     intermediate_size = w3_w1_stacked_weight.shape[1] // 2
     results = torch.zeros_like(x)
@@ -129,7 +133,7 @@ def torch_fused_moe(
         scaling = routing_weights[batch_idx, nth_expert].unsqueeze(-1)
         results[batch_idx] += scaling * expert_out
 
-    return results.view_as(x)
+    return results.view(x_shape)
 
 
 @torch_fused_moe.register_fake
@@ -151,6 +155,9 @@ def trtllm_fused_moe(
     w3_w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
 ) -> torch.Tensor:
+    x_shape = x.shape
+    x = x.view(-1, x_shape[-1])
+
     routing_weights = routing_weights.to(torch.float32)
     selected_experts = selected_experts.to(torch.int32)
     quant_scales = []
@@ -167,7 +174,7 @@ def trtllm_fused_moe(
         tp_rank=0,
         ep_size=1,
         ep_rank=0,
-    )[0]
+    )[0].view(x_shape)
 
 
 @trtllm_fused_moe.register_fake
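The shape handling added above is a flatten-to-2D / restore idiom, so the same op body serves both [num_tokens, hidden] and [batch, seq, hidden] callers. A runnable toy sketch of just that idiom; `toy_moe` and its pass-through "expert" are made-up stand-ins, not the real torch.ops.moe kernels:

import torch


def toy_moe(x: torch.Tensor) -> torch.Tensor:
    x_shape = x.shape                # remember the caller's layout (2D or 3D)
    x = x.view(-1, x_shape[-1])      # flatten to [num_tokens, hidden_dim]
    out = torch.zeros_like(x)
    out += 2 * x                     # stand-in for the per-token expert compute
    return out.view(x_shape)         # restore the original layout


x2d = torch.randn(8, 16)             # [num_tokens, hidden]
x3d = torch.randn(2, 4, 16)          # [batch, seq, hidden]
assert toy_moe(x2d).shape == (8, 16)
assert toy_moe(x3d).shape == (2, 4, 16)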

tensorrt_llm/_torch/auto_deploy/models/deepseek.py (1 addition & 7 deletions)

@@ -129,12 +129,8 @@ def deepseek_v3_moe_exact(self, hidden_states):
 @torch.inference_mode()
 def deepseek_v3_moe(self, hidden_states):
     """DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export."""
-    identity = hidden_states
-    batch_size, sequence_length, hidden_dim = hidden_states.shape
 
     selected_experts, routing_weights, *_ = self.gate(hidden_states)
-    hidden_states = hidden_states.view(-1, hidden_dim)
-
     final_hidden_states = torch.ops.moe.torch_moe(
         hidden_states,
         selected_experts,
@@ -144,10 +140,8 @@ def deepseek_v3_moe(self, hidden_states):
         w3_weight=[expert.up_proj.weight for expert in self.experts],
     )
 
-    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-
     if self.config.n_shared_experts is not None:
-        final_hidden_states = final_hidden_states + self.shared_experts(identity)
+        final_hidden_states = final_hidden_states + self.shared_experts(hidden_states)
 
     return final_hidden_states.to(hidden_states.dtype)
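With the reshape handled inside the custom op, the patched forward keeps hidden_states in its original [batch, seq, hidden] layout end to end, which is why the identity/view/reshape bookkeeping above can be dropped and the shared-expert residual now adds hidden_states directly. A shape-only sketch of that contract, using hypothetical stand-ins `fake_gate` and `fake_moe` in place of self.gate and torch.ops.moe.torch_moe:

import torch

batch, seq, hidden, top_k = 2, 4, 16, 2
hidden_states = torch.randn(batch, seq, hidden)


def fake_gate(x):
    # like the gate: per-token expert ids/weights, flattened over batch*seq
    n_tokens = x.shape[0] * x.shape[1]
    ids = torch.zeros(n_tokens, top_k, dtype=torch.long)
    weights = torch.full((n_tokens, top_k), 1.0 / top_k)
    return ids, weights


def fake_moe(x, selected_experts, routing_weights):
    # the real op flattens x internally and returns it in the caller's shape
    return x


selected_experts, routing_weights = fake_gate(hidden_states)
out = fake_moe(hidden_states, selected_experts, routing_weights)
out = out + 0.1 * hidden_states      # stand-in for self.shared_experts(hidden_states)
assert out.shape == hidden_states.shape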
