  33 |  33 | from vllm.model_executor.utils import set_weight_attrs
  34 |  34 | from vllm.platforms import current_platform
  35 |  35 | from vllm.platforms.interface import CpuArchEnum
  36 |     | -from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
     |  36 | +from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
     |  37 | +                        round_up)
  37 |  38 | from vllm.utils.flashinfer import has_flashinfer
  38 |  39 |
  39 |  40 | if current_platform.is_cuda_alike():
@@ -719,6 +720,12 @@ def __init__(
 719 | 720 |
 720 | 721 |         self.global_num_experts = num_experts + num_redundant_experts
 721 | 722 |
     | 723 | +        # Pad hidden_size globally so that EP buffer allocation works.
     | 724 | +        if quant_config and quant_config.get_name() == "mxfp4" and (
     | 725 | +                envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
     | 726 | +                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
     | 727 | +            hidden_size = round_up(hidden_size, 256)
     | 728 | +
 722 | 729 |         # For smuggling this layer into the fused moe custom op
 723 | 730 |         compilation_config = vllm_config.compilation_config
 724 | 731 |         if prefix in compilation_config.static_forward_context:
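
Note on the padding above: a minimal standalone sketch of the alignment math, assuming round_up matches vllm.utils.round_up (round x up to the nearest multiple); the concrete sizes below are illustrative only.

    def round_up(x: int, multiple: int) -> int:
        # Ceiling-divide, then scale back up to the next aligned value.
        return ((x + multiple - 1) // multiple) * multiple

    # A 2880-wide hidden state is padded up to 3072 (the next multiple of 256),
    # while an already-aligned 4096 is left unchanged.
    assert round_up(2880, 256) == 3072
    assert round_up(4096, 256) == 4096
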
@@ -1064,6 +1071,18 @@ def weight_loader(self,
1064 | 1071 |                       shard_id: str,
1065 | 1072 |                       expert_id: int,
1066 | 1073 |                       return_success: bool = False) -> Optional[bool]:
     | 1074 | +
     | 1075 | +        if self.quant_config and self.quant_config.get_name() == "mxfp4":
     | 1076 | +            # FIXME: for gpt-oss, all experts are combined into a single tensor.
     | 1077 | +            if "bias" in weight_name:
     | 1078 | +                dim1 = loaded_weight.shape[1]
     | 1079 | +                param.data[:, :dim1].copy_(loaded_weight)
     | 1080 | +            else:
     | 1081 | +                dim1 = loaded_weight.shape[1]
     | 1082 | +                dim2 = loaded_weight.shape[2]
     | 1083 | +                param.data[:, :dim1, :dim2].copy_(loaded_weight)
     | 1084 | +            return True if return_success else None
     | 1085 | +
1067 | 1086 |         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
1068 | 1087 |         if expert_id == -1:
1069 | 1088 |             # Failed to load this param since it's not local to this rank
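
For context, a minimal sketch (synthetic shapes, not the vLLM loading API) of what the mxfp4 branch above does: the parameter was allocated with the padded hidden size, while the checkpoint tensor still has the original size, so the loader copies into a leading slice and leaves the padded tail at zero.

    import torch

    num_experts, padded, original = 32, 3072, 2880    # illustrative sizes only
    param = torch.zeros(num_experts, padded)           # bias buffer, padded width
    loaded_weight = torch.randn(num_experts, original)

    dim1 = loaded_weight.shape[1]
    param[:, :dim1].copy_(loaded_weight)               # same pattern as param.data[:, :dim1]
    assert torch.equal(param[:, :original], loaded_weight)
    assert torch.all(param[:, original:] == 0)
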
@@ -1476,13 +1495,20 @@ def maybe_all_reduce_tensor_model_parallel(
1476 | 1495 |
1477 | 1496 |     def forward(self, hidden_states: torch.Tensor,
1478 | 1497 |                 router_logits: torch.Tensor):
     | 1498 | +        og_hidden_states = hidden_states.shape[-1]
     | 1499 | +        if self.hidden_size != og_hidden_states:
     | 1500 | +            hidden_states = F.pad(hidden_states,
     | 1501 | +                                  (0, self.hidden_size - og_hidden_states),
     | 1502 | +                                  mode='constant',
     | 1503 | +                                  value=0.0)
1479 | 1504 |         # TODO: Once the OOM issue for the TPU backend is resolved, we will
1480 | 1505 |         # switch to using the moe_forward custom op.
1481 | 1506 |         if current_platform.is_tpu():
1482 | 1507 |             return self.forward_impl(hidden_states, router_logits)
1483 | 1508 |         else:
1484 |      | -            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
1485 |      | -                                              self.layer_name)
     | 1509 | +            return torch.ops.vllm.moe_forward(
     | 1510 | +                hidden_states, router_logits,
     | 1511 | +                self.layer_name)[..., :og_hidden_states]
1486 | 1512 |
1487 | 1513 |     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
1488 | 1514 |                              full_router_logits: torch.Tensor):
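
The forward() change pads on entry and slices on exit; a minimal round-trip sketch of that logic (the element-wise doubling below is only a stand-in for the fused MoE op, and the widths are illustrative):

    import torch
    import torch.nn.functional as F

    hidden_size, og = 3072, 2880                  # padded vs. original width
    x = torch.randn(4, og)

    # Zero-pad the last dimension up to the padded hidden size.
    padded = F.pad(x, (0, hidden_size - og), mode='constant', value=0.0)
    out = padded * 2.0                            # stand-in for torch.ops.vllm.moe_forward
    out = out[..., :og]                           # slice back to the caller's width
    assert out.shape == x.shape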