|
8 | 8 | from transformers import AutoConfig |
9 | 9 | from lightllm.common.fused_moe.topk_select import select_experts |
10 | 10 | from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl |
11 | | -from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm |
12 | | -from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( |
13 | | - fused_moe as fused_moe_sglang, |
14 | | -) |
15 | 11 |
|
16 | 12 |
|
17 | 13 | def get_model_config(model_name: str, tp_size: int): |
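
The hunks above and below move the vllm and sglang `fused_moe` imports from module scope into the wrapper functions that use them, so importing this benchmark module no longer requires both backends to be installed. A minimal sketch of the pattern; the `run_backend_case` dispatcher is illustrative and not part of this file:

```python
# Deferred-import pattern applied in this commit: each heavy optional backend
# is imported inside the function that benchmarks it, so importing the module
# itself succeeds even when vllm or sglang is absent.
def run_backend_case(backend: str, *args, **kwargs):
    if backend == "vllm":
        # Resolved only when a vllm case actually runs.
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
    elif backend == "sglang":
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
    else:
        raise ValueError(f"unknown backend: {backend!r}")
    return fused_moe(*args, **kwargs)
```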
@@ -59,12 +55,10 @@ def get_model_config(model_name: str, tp_size: int): |
59 | 55 | intermediate_size = config.intermediate_size |
60 | 56 | shard_intermediate_size = 2 * intermediate_size // tp_size |
61 | 57 |
|
62 | | - vllm_version_num = vllm.__version_tuple__[0] * 100 + vllm.__version_tuple__[1] * 10 + vllm.__version_tuple__[2] |
63 | 58 | block_shape = None |
64 | 59 | if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config: |
65 | 60 | block_shape = config.quantization_config["weight_block_size"] |
66 | 61 | assert len(block_shape) == 2 |
67 | | - assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1" |
68 | 62 |
|
69 | 63 | shape_configs = { |
70 | 64 | "num_experts": E, |
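
The removed lines above drop the module-level vllm version arithmetic and its `>=0.6.6.post1` gate, which only made sense while vllm was imported unconditionally; the `block_shape` detection itself is kept. For reference, a standalone sketch of that detection (the `detect_block_shape` helper name is illustrative):

```python
from transformers import AutoConfig

def detect_block_shape(model_name: str):
    # Block-wise fp8 checkpoints publish their quantization block size in
    # config.json under quantization_config["weight_block_size"], e.g. [128, 128].
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    block_shape = None
    if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config:
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2  # two-element block size
    return block_shape
```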
@@ -131,6 +125,7 @@ def fused_moe_vllm_api( |
131 | 125 | a2_scale=None, |
132 | 126 | block_shape=None, |
133 | 127 | ): |
| 128 | + from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm |
134 | 129 | if block_shape is not None: |
135 | 130 | return fused_moe_vllm( |
136 | 131 | x, |
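
Moving the import inside `fused_moe_vllm_api` adds essentially no per-call overhead to the benchmark loop: the first call pays the real import, and every later `import` statement resolves from the `sys.modules` cache. A quick illustration:

```python
import sys

def wrapper():
    # The first call performs the actual import; later calls are a dict lookup.
    import json  # stand-in for a heavy backend module
    return json

wrapper()
assert "json" in sys.modules  # subsequent imports hit this cache
```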
@@ -179,7 +174,9 @@ def fused_moe_sglang_api( |
179 | 174 | ): |
180 | 175 | from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig |
181 | 176 | from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts |
182 | | - |
| 177 | + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( |
| 178 | + fused_moe as fused_moe_sglang, |
| 179 | + ) |
183 | 180 | topk_output = select_experts( |
184 | 181 | hidden_states=x, |
185 | 182 | router_logits=input_gating, |
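
The sglang wrapper already imported its `MoeRunnerConfig` and TopK helpers lazily; this hunk makes the triton `fused_moe` import follow the same convention. A natural companion, not part of this diff and shown only as an assumption, is probing for optional backends before scheduling their cases:

```python
import importlib.util

def backend_available(name: str) -> bool:
    # True if the package is importable, without actually importing it.
    return importlib.util.find_spec(name) is not None

for backend in ("vllm", "sglang"):
    if not backend_available(backend):
        print(f"{backend} not installed; skipping its fused_moe cases")
```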
|