@@ -102,98 +102,38 @@ def torch_moe(
     Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch
     (token routing + index_add_ accumulation) and a selectable per-expert MLP.
 
-    Supports both:
-    - Standard MoE with per-expert weight lists (apply_routing_on_input=False)
-    - Llama4 MoE with stacked weight tensors (apply_routing_on_input=True)
-
     Parameters:
         x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
             S is the sequence length, and H is the hidden size.
         selected_experts (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the indices
             of the selected experts for each token. Only experts within the range [0, num_experts) are processed.
         routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
             routing weights for the selected experts.
-            - Standard MoE: softmax normalized weights
-            - Llama4 MoE: sigmoid activated weights
-        w1_weight:
-            For per-expert lists:
-                • is_gated_mlp==True: List of W1 with shape (I, H) — "gate" projection.
-                • is_gated_mlp==False: List of W_up with shape (I, H) — up projection.
-            For stacked tensors (Llama4):
-                • Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
-        w2_weight:
-            For per-expert lists:
-                • List of W2/W_down with shape (H, I) — down projection.
-            For stacked tensors (Llama4):
-                • Single-element list containing stacked w2 tensor with shape (E, H, I) in TRT-LLM format
-        w3_weight:
-            For per-expert lists with gated_mlp:
-                • List of W3 with shape (I, H) — "up" (second) projection in gated MLP.
-            For is_gated_mlp==False or stacked tensors:
-                • pass an empty list []; ignored.
-        is_gated_mlp:
-            Selects the per-expert MLP computation:
-                • is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style):
-                      y = W2( act(W1 x) * (W3 x) )
-                • is_gated_mlp==False (NemotronH-style 2-layer MLP):
-                      y = W_down( act(W_up x) )
-        act_fn:
-            Elementwise activation applied inside the expert MLP.
+        w1_weight: List of per-expert weight tensors of the gate projection (up projection if not gated).
+        w2_weight: List of per-expert weight tensors of the down projection.
+        w3_weight: List of per-expert weight tensors of the up projection (unused if not gated).
+        is_gated_mlp: If True, use a gated MLP; if False, a simple two-layer MLP.
+        act_fn: Activation function applied inside the expert MLP.
             Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
-        apply_routing_on_input:
-            If True (Llama4 pattern): multiply routing weights with INPUT before MLP
-                Result: act(input * routing_weight) - routing affects activation
-            If False (standard pattern): multiply routing weights with OUTPUT after MLP
-                Result: act(input) * routing_weight - routing scales output
+        apply_routing_on_input: If True, multiply the routing weights with the INPUT before the expert MLP.
+            This means: act(input * routing_weight)
+            If False, multiply the routing weights with the OUTPUT after the expert MLP.
+            This means: act(input) * routing_weight
     Returns:
         torch.Tensor: Output tensor with the same shape as the input x.
     """
     torch_act_fn = _resolve_torch_fn(act_fn)
 
-    # Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
-    is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3
-
-    # Todo: either change torch_moe to use a single condition, or refactor this code.
-    # it should be :
-    # is_gated_mlp:
-    #     stacked:
-    #         ...
-    #     not stacked:
-    #         .
-    # else:
-    #     assert (not stacked)
-    #     ...
-    #     .
-    if is_stacked:
-        # Llama4 stacked tensor format - only supports gated_mlp
-        if not is_gated_mlp:
-            raise ValueError("Stacked tensor format only supports gated MLP style")
-
-        w3_w1_stacked = w1_weight[0]  # (E, 2*I, H)
-        intermediate_size = w3_w1_stacked.shape[1] // 2
-        w2_stacked = w2_weight[0]  # (E, H, I)
-
-        def make_mlp(idx: int):
-            gate_up = w3_w1_stacked[idx]  # (2*I, H)
-            W3 = gate_up[:intermediate_size, :]  # (I, H)
-            W1 = gate_up[intermediate_size:, :]  # (I, H)
-            W2 = w2_stacked[idx]  # (H, I)
-            weight_dtype = W1.dtype
-            return lambda inp: F.linear(
-                torch_act_fn(F.linear(inp.to(weight_dtype), W1))
-                * F.linear(inp.to(weight_dtype), W3),
-                W2,
-            )
-
-        mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])]
-
-    elif is_gated_mlp:
+    mlps = []
+    if is_gated_mlp:
         # Standard per-expert list format with gated MLP
         def make_mlp(i: int):
             W1 = w1_weight[i]  # (I, H)
             W2 = w2_weight[i]  # (H, I)
             W3 = w3_weight[i]  # (I, H)
-            return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
+            return lambda inp: F.linear(
+                torch_act_fn(F.linear(inp.to(W1.dtype), W1)) * F.linear(inp.to(W3.dtype), W3), W2
+            )
 
         mlps = [make_mlp(i) for i in range(len(w1_weight))]
 
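A minimal usage sketch of the simplified per-expert-list interface, following the shapes in the docstring above. The import is only a placeholder (the module paths for `torch_moe` and `ActivationType` depend on the repo layout), and the toy router and random weights are assumptions for illustration, not part of this change:

```python
import torch

# Placeholder import: adjust to wherever torch_moe and ActivationType live in this repo.
# from <moe_module> import torch_moe, ActivationType

B, H, I, E, TOP_K = 4, 64, 128, 8, 2

x = torch.randn(B, H)                                             # (B, H) input tokens

# Toy router: top-k over dense gating logits, softmax-normalized over the selected experts.
logits = torch.randn(B, E)
topk_vals, selected_experts = torch.topk(logits, TOP_K, dim=-1)   # (B, TOP_K) expert indices
routing_weights = torch.softmax(topk_vals, dim=-1)                # (B, TOP_K) normalized weights

# Per-expert weight lists: gate/up projections are (I, H), the down projection is (H, I).
w1_weight = [torch.randn(I, H) for _ in range(E)]
w2_weight = [torch.randn(H, I) for _ in range(E)]
w3_weight = [torch.randn(I, H) for _ in range(E)]

out = torch_moe(
    x=x,
    selected_experts=selected_experts,
    routing_weights=routing_weights,
    w1_weight=w1_weight,
    w2_weight=w2_weight,
    w3_weight=w3_weight,
    is_gated_mlp=True,
    act_fn=ActivationType.Silu,
    apply_routing_on_input=False,
)
assert out.shape == x.shape   # output keeps the input shape per the docstring
```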
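And a self-contained sketch of the Mixtral-style dispatch (token routing + `index_add_` accumulation) that the docstring refers to, using the standard `apply_routing_on_input=False` behaviour with a gated SiLU MLP. This is an illustrative re-implementation with assumed toy sizes, not the code in this module:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, H, I, E, TOP_K = 6, 16, 32, 4, 2               # tokens, hidden, intermediate, experts, top-k

x = torch.randn(T, H)
w1 = [torch.randn(I, H) for _ in range(E)]        # gate projection
w2 = [torch.randn(H, I) for _ in range(E)]        # down projection
w3 = [torch.randn(I, H) for _ in range(E)]        # up projection

logits = torch.randn(T, E)
topk_vals, selected_experts = torch.topk(logits, TOP_K, dim=-1)   # (T, TOP_K)
routing_weights = torch.softmax(topk_vals, dim=-1)                # (T, TOP_K)

out = torch.zeros_like(x)
for e in range(E):
    # Which (token, top-k slot) pairs picked expert e?
    token_idx, slot_idx = torch.where(selected_experts == e)
    if token_idx.numel() == 0:
        continue
    inp = x[token_idx]                                            # gather this expert's tokens
    # Gated MLP: down( silu(gate(x)) * up(x) )
    expert_out = F.linear(F.silu(F.linear(inp, w1[e])) * F.linear(inp, w3[e]), w2[e])
    # Standard pattern (apply_routing_on_input=False): scale the OUTPUT by the routing weight.
    expert_out = expert_out * routing_weights[token_idx, slot_idx].unsqueeze(-1)
    out.index_add_(0, token_idx, expert_out)                      # accumulate back per token
```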