
Commit 8826599

Authored by gshtrast, with co-authors tjtanaa and hongxiayang
Fix fused moe (ROCm#506)
* Added the extra use_irope parameter in

  Co-authored-by: Hongxia Yang <[email protected]>
  Signed-off-by: tjtanaa <[email protected]>

* Fix ROCm V1 Engine Fused MoE Bug

  Signed-off-by: tjtanaa <[email protected]>

* Add warning message that V0 do not support irope

  Signed-off-by: tjtanaa <[email protected]>

---------

Signed-off-by: tjtanaa <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
1 parent d17d4df commit 8826599


2 files changed (+19, -1 lines)


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 5 additions & 1 deletion

@@ -462,11 +462,15 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
                 "ROCmFlashAttention does not support blocksparse attention.")
-
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back "
+                "to global attention for long context.")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             self.logits_soft_cap = 0.0
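The change here is additive: the ROCm V0 backend now accepts the `use_irope` keyword that callers may pass and only warns that interleaved-RoPE local attention is not implemented, falling back to global attention. Below is a minimal, hypothetical sketch of that pattern; `ToyAttentionImpl` and its signature are illustrative stand-ins, not the real vLLM class.

import logging

logger = logging.getLogger(__name__)


class ToyAttentionImpl:
    """Illustrative stand-in for an attention backend constructor."""

    def __init__(self, attn_type: str = "decoder",
                 use_irope: bool = False) -> None:
        if use_irope:
            # Accept the flag but warn: this toy V0-style path has no
            # interleaved-RoPE support and behaves as global attention.
            logger.warning(
                "Using irope in V0 is not supported yet, it will fall back "
                "to global attention for long context.")
        self.attn_type = attn_type


# Without `use_irope` in the signature, a caller passing it would fail with
# TypeError: __init__() got an unexpected keyword argument 'use_irope'.
impl = ToyAttentionImpl(use_irope=True)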

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 14 additions & 0 deletions

@@ -461,6 +461,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             use_int4_w4a16: bool,
                             block_shape: Optional[List[int]] = None) -> None:
     assert topk_weights is not None or not mul_routed_weight
+    if current_platform.is_rocm() and topk_weights is not None:
+        # This is to handle the bug https://github.com/ROCm/pytorch/issues/2020
+        # where, in the HIPGraph, it can occur that the `topk_weights`
+        # tensor has the following properties:
+        #   .shape: [1024, 1]
+        #   .is_contiguous(): True
+        #   .stride(): [1, 1024]
+        #   .is_contiguous(memory_format=torch.channels_last) is False
+        #   .is_contiguous(memory_format=torch.contiguous_format) is True
+        # This only happens when using the V1 Engine on ROCm with HIPGraph
+        # with torch.compile Dynamo.
+        # The V1 Engine on ROCm with eager mode is fine.
+        # The V0 Engine on ROCm with HIPGraph is fine.
+        topk_weights = topk_weights.view(-1).reshape(topk_weights.shape)
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
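The comment block above describes a stride anomaly rather than a shape change: a [1024, 1] tensor whose trailing stride is 1024 instead of 1, yet which PyTorch still reports as contiguous because size-1 dimensions are skipped by the contiguity check. The sketch below reproduces the same stride pattern via a transpose (only a stand-in for the HIPGraph/torch.compile behaviour described in the commit) and shows why `view(-1).reshape(...)` restores a canonical stride so the downstream `stride(1) == 1` assertion holds.

import torch

# Hypothetical reproduction of the stride pattern described in the comment.
# Transposing a [1, 1024] tensor yields shape [1024, 1] with stride (1, 1024),
# which PyTorch still reports as contiguous (size-1 dims are ignored by the
# contiguity check).
topk_weights = torch.randn(1, 1024).t()
print(topk_weights.shape)            # torch.Size([1024, 1])
print(topk_weights.stride())         # (1, 1024)
print(topk_weights.is_contiguous())  # True
# topk_weights.stride(1) == 1024, so `assert topk_weights.stride(1) == 1` fails.

# The workaround from the patch: flatten, then reshape back to the same shape.
# This re-derives the canonical stride (1, 1); both ops return views here.
topk_weights = topk_weights.view(-1).reshape(topk_weights.shape)
print(topk_weights.stride())         # (1, 1)
assert topk_weights.stride(1) == 1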

0 commit comments
