
Commit 68b4a38

vllmellm authored and jingyu committed
[ROCm][AITER] Support AITER Rope ops in RotaryEmbedding Module. (vllm-project#22521)
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: jingyu <[email protected]>
1 parent e366362 commit 68b4a38

File tree

vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/common.py
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py

4 files changed (+204 -10 lines changed)

vllm/model_executor/layers/rotary_embedding/base.py

Lines changed: 71 additions & 0 deletions
@@ -8,6 +8,7 @@
 from vllm.model_executor.custom_op import CustomOp
 
 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
+from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
 
 
 @CustomOp.register("rotary_embedding")
@@ -35,6 +36,7 @@ def __init__(
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
 
     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -119,6 +121,75 @@ def forward_cuda(
                             self.cos_sin_cache, self.is_neox_style)
         return query, key
 
+    def forward_hip(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        is_nope_first=False,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # currently only rotary embedding ops from AITER package are
+        # supported for HiP forward.
+        if self.is_rocm_aiter_enabled:
+            return self.forward_hip_rocm_aiter(positions, query, key, offsets,
+                                               is_nope_first)
+        return self.forward_native(positions, query, key, offsets)
+
+    def forward_hip_rocm_aiter(
+        self,
+        positions: torch.Tensor,
+        # if is_nope_first
+        #   [batch_size, seq_len, num_heads, nope_size+rope_size]
+        # if NOT is_nope_first
+        #   [batch_size, seq_len, num_heads, rope_size+nope_size]
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        is_nope_first: bool = False,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.cos_sin_cache.device != query.device or \
+                self.cos_sin_cache.dtype != query.dtype:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
+                                                       dtype=query.dtype)
+        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
+
+        cos = cos.unsqueeze(-2).unsqueeze(-2)
+        sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+        rotate_style = 0 if self.is_neox_style else 1
+
+        num_tokens = positions.numel()
+
+        query_shape = query.shape
+        query = query.view(1, num_tokens, -1, self.head_size)
+        if key is not None:
+            key_shape = key.shape
+            key = key.view(1, num_tokens, -1, self.head_size)
+
+        positions = positions.view(*query.shape[:2])
+        if offsets is not None:
+            offsets = offsets.view(*query.shape[:2])
+
+        if not is_nope_first:
+            query_ = query[..., :self.rotary_dim]
+            key_ = key[..., :self.rotary_dim] if key is not None else None
+        else:
+            query_ = query[..., -self.rotary_dim:]
+            key_ = key[..., -self.rotary_dim:] if key is not None else None
+
+        if key_ is None:
+            torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
+                positions, sin, cos, query_, offsets, rotate_style,
+                is_nope_first)
+            return query.view(query_shape), None
+
+        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
+            positions, sin, cos, query_, key_, offsets, rotate_style,
+            is_nope_first)
+
+        return query.view(query_shape), key.view(key_shape)
+
     def forward_xpu(
         self,
         positions: torch.Tensor,
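
For orientation (editorial note, not part of the commit): the sketch below, with made-up sizes, mirrors the shape bookkeeping that forward_hip_rocm_aiter performs before invoking the in-place AITER kernels. Queries and keys are viewed as [1, num_tokens, num_heads, head_size], positions are viewed as [1, num_tokens], and only the rotary_dim-wide slice is handed to the op.

import torch

# Hypothetical sizes for illustration only.
num_tokens, num_heads, head_size, rotary_dim = 8, 4, 128, 64

positions = torch.arange(num_tokens)
query = torch.randn(num_tokens, num_heads * head_size)

# Same reshapes as forward_hip_rocm_aiter:
q = query.view(1, num_tokens, -1, head_size)   # [1, num_tokens, num_heads, head_size]
pos = positions.view(*q.shape[:2])             # [1, num_tokens]
q_rot = q[..., :rotary_dim]                    # rope slice (is_nope_first=False layout)
assert q_rot.shape == (1, num_tokens, num_heads, rotary_dim)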

vllm/model_executor/layers/rotary_embedding/common.py

Lines changed: 2 additions & 2 deletions
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
     return ramp_func
 
 
-def yarn_get_mscale(scale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
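
As a quick sanity check of the consolidated helper (editorial example, not part of the commit; the inputs are arbitrary): the new mscale argument simply scales the logarithmic term and defaults to 1, so existing single-argument callers keep the old behaviour.

import math

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Standalone copy of the helper above, for illustration.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

print(yarn_get_mscale(4.0))          # 0.1 * 1 * ln(4) + 1 ≈ 1.139 (previous behaviour)
print(yarn_get_mscale(4.0, 0.707))   # 0.1 * 0.707 * ln(4) + 1 ≈ 1.098
print(yarn_get_mscale(0.5, 0.707))   # 1.0, since scale <= 1 short-circuits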

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

Lines changed: 4 additions & 8 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import math
 from typing import Optional
 
 import torch
@@ -10,13 +9,7 @@
 
 from .base import RotaryEmbedding
 from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_linear_ramp_mask)
-
-
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+                     yarn_get_mscale, yarn_linear_ramp_mask)
 
 
 class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -96,6 +89,9 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
+        if self.is_rocm_aiter_enabled:
+            return self.forward_hip_rocm_aiter(positions, query, key, offsets)
+
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
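
A brief note on how this early return is exercised (editorial sketch, not part of the commit; rope, positions, query, and key are placeholder names for an existing DeepseekScalingRotaryEmbedding layer and its inputs):

# Callers keep invoking forward() as before; the routing happens inside it.
out_q, out_k = rope.forward(positions, query, key)
# When rope.is_rocm_aiter_enabled is True (ROCm build with VLLM_ROCM_USE_AITER set),
# forward() returns early via the inherited forward_hip_rocm_aiter(), which rotates
# query/key in place with AITER kernels; otherwise the PyTorch-native body below
# the `assert key is not None` runs unchanged.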
vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
+
+def is_rocm_rotary_embedding_enabled() -> bool:
+    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
+
+
+def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    import aiter as ops
+    if offsets is None:
+        ops.rope_cached_positions_fwd_inplace(
+            query,
+            cos,
+            sin,
+            positions,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+    else:
+        ops.rope_cached_positions_offsets_fwd_inplace(
+            query,
+            cos,
+            sin,
+            positions,
+            offsets,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    import aiter as ops
+    if offsets is None:
+        ops.rope_cached_positions_2c_fwd_inplace(
+            query,
+            key,
+            cos,
+            sin,
+            positions,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+    else:
+        ops.rope_cached_positions_offsets_2c_fwd_inplace(
+            query,
+            key,
+            cos,
+            sin,
+            positions,
+            offsets,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    pass
+
+
+def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    pass
+
+
+if is_rocm_rotary_embedding_enabled():
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
+        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
+        mutates_args=["key", "query"],
+        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
+        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
+        mutates_args=["query"],
+        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
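
Finally, a usage sketch (editorial, not part of the commit; it assumes the new module lives at vllm.model_executor.layers.rotary_embedding.rocm_aiter_rope_ops, as the relative import in base.py suggests): the ops are only registered under torch.ops.vllm when the ROCm/AITER guard passes, and callers such as forward_hip_rocm_aiter invoke them positionally and rely on the in-place mutation of query/key.

import torch

from vllm.model_executor.layers.rotary_embedding.rocm_aiter_rope_ops import (
    is_rocm_rotary_embedding_enabled)

# Registration only happens on a ROCm build with VLLM_ROCM_USE_AITER=1, so guard
# any direct use of the ops the same way the module itself does.
if is_rocm_rotary_embedding_enabled():
    op = torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip
    # Positional call order, as used in base.py:
    #   op(positions, sin, cos, query_, key_, offsets, rotate_style, is_nope_first)
    # rotate_style is 0 for NeoX-style rotation and 1 for GPT-J-style; both ops
    # mutate query/key in place and return None.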

0 commit comments
