8 | 8 |
9 | 9 | from vllm import _custom_ops as ops |
10 | 10 | from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight |
11 | | -from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul |
| 11 | +from vllm.model_executor.layers.activation import SiluAndMul |
12 | 12 | from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter |
13 | 13 | from vllm.utils.torch_utils import direct_register_custom_op |
14 | 14 |
15 | 15 | _CPU_MOE_LAYER_CACHE = {} |
16 | 16 |
17 | 17 |
18 | | -class _LazyActivationDict(dict): |
19 | | -    """Lazily instantiate activation functions on first access. |
| 18 | +def _swigluoai_forward_native( |
| 19 | +    x: torch.Tensor, |
| 20 | +    alpha: float = 1.702, |
| 21 | +    limit: float = 7.0, |
| 22 | +) -> torch.Tensor: |
| 23 | +    """PyTorch-native implementation of SwigluOAIAndMul.forward_native. |
20 | 24 |
21 | | -    Avoids triggering CustomOp.__init__() at module import time, |
22 | | -    which would call get_current_vllm_config() before config is set. |
| 25 | +    Standalone function to avoid instantiating SwigluOAIAndMul (a CustomOp) |
| 26 | +    which would trigger get_current_vllm_config() before config is set. |
23 | 27 |     """ |
| 28 | +    gate, up = x[..., ::2], x[..., 1::2] |
| 29 | +    gate = gate.clamp(min=None, max=limit) |
| 30 | +    up = up.clamp(min=-limit, max=limit) |
| 31 | +    glu = gate * torch.sigmoid(gate * alpha) |
| 32 | +    gated_output = (up + 1) * glu |
| 33 | +    return gated_output |
24 | 34 |
25 | | -    _factories: dict[str, type[SiluAndMul] | type[SwigluOAIAndMul]] = { |
26 | | -        "silu": SiluAndMul, |
27 | | -        "swigluoai": SwigluOAIAndMul, |
28 | | -    } |
29 | 35 |
30 | | -    def __missing__(self, key: str) -> SiluAndMul | SwigluOAIAndMul: |
31 | | -        if key not in self._factories: |
32 | | -            raise KeyError(f"{key} is not a supported activation") |
33 | | -        self[key] = self._factories[key]() |
34 | | -        return self[key] |
35 | | - |
36 | | - |
37 | | -_CPU_MOE_ACT = _LazyActivationDict() |
| 36 | +# Map activation names to their native forward functions. |
| 37 | +# Uses static methods or standalone functions to avoid instantiating CustomOp |
| 38 | +# classes, which would call get_current_vllm_config() before config is set. |
| 39 | +_CPU_MOE_ACT_FN: dict[str, Callable[[torch.Tensor], torch.Tensor]] = { |
| 40 | +    "silu": SiluAndMul.forward_native, |
| 41 | +    "swigluoai": _swigluoai_forward_native, |
| 42 | +} |
38 | 43 |
39 | 44 |
40 | 45 | def grouped_topk( |
@@ -230,7 +235,7 @@ def __call__( |
230 | 235 |         apply_router_weight_on_input: bool = False, |
231 | 236 |         activation: str = "silu", |
232 | 237 |     ) -> torch.Tensor: |
233 | | -        assert activation in _CPU_MOE_ACT._factories, f"{activation} is not supported." |
| 238 | +        assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported." |
234 | 239 |         assert not apply_router_weight_on_input |
235 | 240 |
236 | 241 |         topk_weights, topk_ids = select_experts( |
@@ -418,7 +423,7 @@ def cpu_fused_moe_torch( |
418 | 423 |         tokens_for_this_expert = sorted_tokens[start_idx:end_idx] |
419 | 424 |
420 | 425 |         gate_up = layer.gate_up_linear[i](tokens_for_this_expert)  # type: ignore |
421 | | -        gate_up = _CPU_MOE_ACT[activation].forward_native(gate_up) |
| 426 | +        gate_up = _CPU_MOE_ACT_FN[activation](gate_up) |
422 | 427 |         expert_out = layer.down_linear[i](gate_up)  # type: ignore |
423 | 428 |         outputs.append(expert_out) |
424 | 429 |         start_idx = end_idx |
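For reference, a minimal standalone sketch of the SwigluOAI helper added above, with an illustrative shape check (the 4×16 input is made up for the example): even-indexed channels act as the gate, odd-indexed channels as the up projection, so the output has half as many channels as the gate_up input fed to the per-expert down projection.

```python
import torch


def _swigluoai_forward_native(
    x: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> torch.Tensor:
    # Same math as the helper in the diff: interleaved gate/up channels,
    # clamped, then SwiGLU-style gating of (up + 1).
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    return (up + 1) * glu


# Illustrative shapes: 4 tokens, 16 interleaved gate/up channels -> 8 output channels.
x = torch.randn(4, 16)
out = _swigluoai_forward_native(x)
assert out.shape == (4, 8)
```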
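And a sketch of the dispatch pattern the change moves to: a module-level dict of plain callables keyed by activation name, so nothing instantiates a CustomOp (and hence nothing calls get_current_vllm_config()) at import time. The _silu_and_mul_native stand-in below uses the standard split-in-half silu-and-mul formulation and is only assumed to match SiluAndMul.forward_native; names and shapes are illustrative, not taken from the change.

```python
from collections.abc import Callable

import torch
import torch.nn.functional as F


def _silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # Split layout: first half of the last dim is the gate, second half is up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


# Plain functions, not CustomOp instances, so building this table is safe
# before the vLLM config exists.
_ACT_FN: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "silu": _silu_and_mul_native,
    # "swigluoai": _swigluoai_forward_native,  # helper from the sketch above
}

gate_up = torch.randn(4, 16)
assert "silu" in _ACT_FN, "silu is not supported."
hidden = _ACT_FN["silu"](gate_up)  # replaces _CPU_MOE_ACT["silu"].forward_native(gate_up)
assert hidden.shape == (4, 8)
```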