Commit 4d63faf

add AITER Triton RoPE as a registered op with VLLM_USE_AITER_TRITON_ROPE=1
1 parent a45886a

3 files changed: +87 -29 lines changed

vllm/model_executor/layers/rotary_embedding/base.py

Lines changed: 20 additions & 28 deletions
@@ -5,12 +5,11 @@
 
 import torch
 
-from vllm import envs
+import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
 
 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
-from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled
-
+from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled, is_rocm_triton_rotary_embedding_enabled
 
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
@@ -38,6 +37,7 @@ def __init__(
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
         self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
+        self.is_rocm_aiter_triton_enabled = is_rocm_triton_rotary_embedding_enabled()
 
     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
@@ -110,8 +110,8 @@ def forward_cuda(
                                                        dtype=query.dtype)
 
         num_tokens = positions.numel()
-        if envs.VLLM_USE_AITER_TRITON_ROPE and num_tokens <= 128:
-            import aiter.ops.triton.rope as ops
+
+        if envs.VLLM_USE_AITER_TRITON_ROPE:
             assert key is not None
             cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
             query_shape = query.shape
@@ -120,30 +120,20 @@ def forward_cuda(
             key = key.view(num_tokens, -1, self.head_size)
             query_ = query[..., :self.rotary_dim]
             key_ = key[..., :self.rotary_dim]
+            rotate_style = 0 if self.is_neox_style else 1
+            positions = positions.view(*query.shape[:1])
             if offsets is not None:
                 offsets = offsets.view(*query.shape[:1])
-                ops.rope_cached_thd_positions_offsets_2c_fwd_inplace(
-                    query_,
-                    key_,
-                    cos,
-                    sin,
-                    positions,
-                    offsets,
-                    0,
-                    True,
-                    False,
-                )
-            else:
-                ops.rope_cached_thd_positions_2c_fwd_inplace(
-                    query_,
-                    key_,
-                    cos,
-                    sin,
-                    positions,
-                    0,
-                    True,
-                    False,
-                )
+            torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
+                positions,
+                sin,
+                cos,
+                query_,
+                key_,
+                offsets,
+                rotate_style,
+                False,
+            )
             query = query.view(query_shape)
             key = key.view(key_shape)
         else:
@@ -173,7 +163,9 @@ def forward_hip(
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         # currently only rotary embedding ops from AITER package are
         # supported for HiP forward.
-        if self.is_rocm_aiter_enabled:
+        if self.is_rocm_aiter_triton_enabled:
+            return self.forward_cuda(positions, query, key, offsets)
+        elif self.is_rocm_aiter_enabled:
             return self.forward_hip_rocm_aiter(positions, query, key, offsets,
                                                is_nope_first)
         return self.forward_native(positions, query, key, offsets)
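For context: once the op is registered (a ROCm build with VLLM_ROCM_USE_AITER=1 and VLLM_USE_AITER_TRITON_ROPE=1 exported before vLLM is imported), the forward_cuda path above boils down to a single in-place torch.ops.vllm call. A minimal sketch, not part of the commit, with hypothetical shapes and a cos/sin cache built the same way RotaryEmbedding builds self.cos_sin_cache:

import torch

# Hypothetical sizes for illustration only.
num_tokens, num_heads, head_size, max_pos, base = 8, 4, 64, 4096, 10000.0
rotary_dim = head_size  # full-rotary case, so query_/key_ cover the whole head

# Same recipe as RotaryEmbedding's cache: cos and sin concatenated on the last
# dim, then split back out with chunk(2, dim=-1) exactly as forward_cuda does.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
cos, sin = torch.cat((freqs.cos(), freqs.sin()), dim=-1).cuda().chunk(2, dim=-1)

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_heads, head_size, device="cuda")
key = torch.randn(num_tokens, num_heads, head_size, device="cuda")

# Mutates query and key in place and returns None; rotate_style=0 is the
# NeoX-style pairing, offsets=None selects the no-offset kernel path.
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
    positions, sin, cos, query, key, None, 0, False)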

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

Lines changed: 3 additions & 1 deletion
@@ -89,7 +89,9 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
-        if self.is_rocm_aiter_enabled:
+        if self.is_rocm_aiter_triton_enabled:
+            return self.forward_cuda(positions, query, key, offsets)
+        elif self.is_rocm_aiter_enabled:
             return self.forward_hip_rocm_aiter(positions, query, key, offsets)
 
         assert key is not None
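Both call sites now share the same three-level dispatch priority. A condensed sketch of that ordering (the helper name _rope_dispatch is invented here; the flags and methods are the real ones from the diffs above):

def _rope_dispatch(self):
    # 1. AITER Triton RoPE: needs ROCm plus VLLM_ROCM_USE_AITER=1
    #    and VLLM_USE_AITER_TRITON_ROPE=1; routed through forward_cuda.
    if self.is_rocm_aiter_triton_enabled:
        return self.forward_cuda
    # 2. AITER HIP RoPE: needs ROCm plus VLLM_ROCM_USE_AITER=1 only.
    if self.is_rocm_aiter_enabled:
        return self.forward_hip_rocm_aiter
    # 3. Fallback: the pure-PyTorch implementation.
    return self.forward_native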

vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py

Lines changed: 64 additions & 0 deletions
@@ -13,6 +13,9 @@
 def is_rocm_rotary_embedding_enabled() -> bool:
     return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)
 
+def is_rocm_triton_rotary_embedding_enabled() -> bool:
+    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_TRITON_ROPE)
+
 
 def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
     positions: torch.Tensor,
@@ -124,4 +127,65 @@ def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
     mutates_args=["query"],
     fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
     dispatch_key=current_platform.dispatch_key,
+)
+
+
+
+def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    import aiter.ops.triton.rope as ops
+    if offsets is None:
+        ops.rope_cached_thd_positions_2c_fwd_inplace(
+            query,
+            key,
+            cos,
+            sin,
+            positions,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+    else:
+        ops.rope_cached_thd_positions_offsets_2c_fwd_inplace(
+            query,
+            key,
+            cos,
+            sin,
+            positions,
+            offsets,
+            rotate_style,
+            reuse_freqs_front_part=True,
+            nope_first=is_nope_first,
+        )
+
+
+def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
+    positions: torch.Tensor,
+    sin: torch.Tensor,
+    cos: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    rotate_style: int = 0,
+    is_nope_first: bool = False,
+) -> None:
+    pass
+
+
+if is_rocm_triton_rotary_embedding_enabled():
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
+        op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
+        mutates_args=["key", "query"],
+        fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
+        dispatch_key=current_platform.dispatch_key,
 )
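The impl/fake pairing exists so torch.compile can trace through the mutating op without launching the Triton kernel: the fake runs during tracing, and since the op mutates its arguments and returns None, a no-op body suffices. A self-contained analogue of the same pattern, using plain torch.library.custom_op instead of vLLM's direct_register_custom_op helper and a toy op name (demo::scale_inplace) invented for this sketch:

import torch

# Real implementation: mutates x in place, returns nothing.
@torch.library.custom_op("demo::scale_inplace", mutates_args=["x"])
def scale_inplace(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)

# Fake (meta) implementation: an in-place op with no outputs has nothing
# to allocate, so pass is enough, just like the *_fake functions above.
@scale_inplace.register_fake
def _(x: torch.Tensor, factor: float) -> None:
    pass

t = torch.ones(4)
torch.ops.demo.scale_inplace(t, 2.0)
print(t)  # tensor([2., 2., 2., 2.])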
