
Commit 5e4a822

[Qwen][ROCm] Flash Attention Rotary Embeddings (vllm-project#24642)
Signed-off-by: vllmellm <[email protected]>
1 parent e51de38 commit 5e4a822

2 files changed: +28 −5 lines

vllm/model_executor/layers/rotary_embedding/common.py

Lines changed: 23 additions & 0 deletions
@@ -2,15 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
+from functools import cache
+from importlib.util import find_spec
+from typing import Callable
 
 import torch
 
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
+logger = init_logger(__name__)
+
 
 # common functions
 def rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
     return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
 
 
+@cache
+def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
+    if current_platform.is_cuda():
+        return apply_rotary_emb
+
+    if current_platform.is_rocm():
+        if find_spec("flash_attn") is not None:
+            from flash_attn.ops.triton.rotary import apply_rotary
+            return apply_rotary
+        else:
+            logger.warning(
+                "flash_attn is not installed. Falling back to PyTorch "
+                "implementation for rotary embeddings.")
+
+    return apply_rotary_emb_torch
+
+
 # yarn functions
 # Inverse dim formula to find dim based on number of rotations
 def yarn_find_correction_dim(num_rotations: int,
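
The new helper centralizes backend selection and caches it per process via @cache: CUDA keeps using vllm_flash_attn's fused rotary kernel, ROCm picks up flash_attn's Triton rotary kernel when that package is importable, and everything else falls back to the pure-PyTorch path. For readers unfamiliar with what all three backends compute, the sketch below is a minimal, self-contained PyTorch illustration of NeoX-style (non-interleaved) rotary application; the function name and the shape conventions in the docstring are made up for the example, and this is not the exact apply_rotary_emb_torch defined in this file.

import torch


def rotary_torch_sketch(x: torch.Tensor, cos: torch.Tensor,
                        sin: torch.Tensor) -> torch.Tensor:
    """Rotate the first 2 * cos.shape[-1] features of x, NeoX-style.

    Assumed shapes (illustrative only):
        x:   (..., seq_len, num_heads, head_dim)
        cos: (seq_len, rotary_dim // 2)
        sin: (seq_len, rotary_dim // 2)
    """
    rotary_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    # Broadcast cos/sin over the heads dimension.
    cos = cos[:, None, :]
    sin = sin[:, None, :]
    # Each pair (x1, x2) is rotated by its position's angle.
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)

The fused CUDA and Triton paths are expected to compute the same rotation, just in a single kernel without materializing the intermediates; because the dispatcher is cached, the ROCm import check and the fallback warning fire at most once per process.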

vllm/model_executor/models/qwen2_vl.py

Lines changed: 5 additions & 5 deletions
@@ -50,6 +50,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.common import (
+    dispatch_rotary_emb_function)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -63,7 +65,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
+    rotary_emb_function = dispatch_rotary_emb_function()
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    apply_rotary_emb = apply_rotary_emb_torch
-    if current_platform.is_cuda():
-        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
-    output = apply_rotary_emb(t_, cos, sin).type_as(t)
+    output = rotary_emb_function(t_, cos, sin).type_as(t)
     return output
 
 
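In the vision path, apply_rotary_pos_emb_vision now simply applies whatever callable the dispatcher returned; the float32 upcast and the type_as cast back to the caller's dtype are unchanged. The freqs argument holds per-position rotation angles, conventionally an outer product of token positions with inverse frequencies base^(-2i/d). As a hedged illustration of what feeds the cos/sin used above, a generic RoPE angle table could be built as follows; the helper name, the base of 10000, and the shapes are standard RoPE conventions assumed for the example, not values taken from this commit.

import torch


def build_rotary_freqs(seq_len: int, rotary_dim: int,
                       base: float = 10000.0) -> torch.Tensor:
    """Return rotation angles of shape (seq_len, rotary_dim // 2)."""
    inv_freq = 1.0 / (base ** (
        torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    # freqs.cos() and freqs.sin() are what the dispatched kernel consumes.
    return torch.outer(positions, inv_freq)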