@@ -50,6 +50,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.common import (
+    dispatch_rotary_emb_function)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -63,7 +65,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
+    rotary_emb_function = dispatch_rotary_emb_function()
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    apply_rotary_emb = apply_rotary_emb_torch
-    if current_platform.is_cuda():
-        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
-    output = apply_rotary_emb(t_, cos, sin).type_as(t)
+    output = rotary_emb_function(t_, cos, sin).type_as(t)
     return output
 
 
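For readers unfamiliar with the helper: instead of each vision model branching on `current_platform` and importing the fused kernel inline, `dispatch_rotary_emb_function()` resolves the implementation once and returns a callable with the `(x, cos, sin)` signature used above. Below is a minimal, standalone sketch of that dispatch pattern, assuming it mirrors the inline logic removed by this diff; the tensor shapes, the `torch.cuda.is_available()` check, and the `ImportError` fallback are illustrative assumptions, not the real helper in `vllm/model_executor/layers/rotary_embedding/common.py`.

```python
# Sketch only: an assumed stand-in for vLLM's dispatch helper, modeled
# on the inline branch this diff removes.
from typing import Callable

import torch


def apply_rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor,
                           sin: torch.Tensor) -> torch.Tensor:
    """Portable PyTorch fallback; non-interleaved layout assumed.

    x: (seqlen, num_heads, head_dim); cos/sin: (seqlen, rot_dim // 2).
    """
    rot_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    # Add a broadcast axis so cos/sin apply across all heads.
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)


def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
    """Resolve the best available rotary kernel once, at the call site."""
    if torch.cuda.is_available():
        try:
            # Fused kernel shipped with vLLM's flash-attention build;
            # the same import the old inline branch used.
            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
            return apply_rotary_emb
        except ImportError:
            pass
    return apply_rotary_emb_torch
```

The apparent intent of the change: centralizing the dispatch keeps per-model code free of platform checks and kernel imports, so every vision encoder gets the same fused-kernel/PyTorch fallback behavior from one place.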