Skip to content

Commit dd96465

Browse files
xuechendi and ywang96 authored
[BugFix][QWEN-VL]fix wrong apply_rotary_emb_torch selection introduced by vllm-project#24642 (vllm-project#26123)
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 4f8f47e commit dd96465

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

vllm/model_executor/layers/rotary_embedding/common.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import math
55
from functools import cache
66
from importlib.util import find_spec
7-
from typing import Callable
7+
from typing import Callable, Optional
88

99
import torch
1010

@@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
7272

7373

7474
@cache
75-
def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
75+
def dispatch_rotary_emb_function(
76+
default: Optional[Callable[..., torch.Tensor]] = None
77+
) -> Callable[..., torch.Tensor]:
7678
if current_platform.is_cuda():
7779
return apply_rotary_emb
7880

@@ -85,7 +87,10 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
8587
"flash_attn is not installed. Falling back to PyTorch "
8688
"implementation for rotary embeddings.")
8789

88-
return apply_rotary_emb_torch
90+
if default is not None:
91+
return default
92+
else:
93+
return apply_rotary_emb_torch
8994

9095

9196
# yarn functions

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor,
276276

277277
def apply_rotary_pos_emb_vision(t: torch.Tensor,
278278
freqs: torch.Tensor) -> torch.Tensor:
279-
rotary_emb_function = dispatch_rotary_emb_function()
279+
rotary_emb_function = dispatch_rotary_emb_function(
280+
default=apply_rotary_emb_torch)
280281
t_ = t.float()
281282
cos = freqs.cos()
282283
sin = freqs.sin()

0 commit comments

Comments
 (0)