Skip to content

Commit 9edcf5e

Browse files
committed
Refactor SparseAttnIndexer as a CustomOp
Signed-off-by: ganyi <[email protected]>
1 parent 3b45a44 commit 9edcf5e

File tree

8 files changed

+622
-339
lines changed

8 files changed

+622
-339
lines changed

vllm/_aiter_ops.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,10 @@
77
from torch._ops import OpOverload
88

99
import vllm.envs as envs
10+
from vllm.attention.ops.rocm_aiter_mla_sparse import (
11+
rocm_aiter_sparse_attn_indexer,
12+
rocm_aiter_sparse_attn_indexer_fake,
13+
)
1014
from vllm.platforms import current_platform
1115
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
1216

@@ -1067,6 +1071,14 @@ def register_ops_once() -> None:
10671071
dispatch_key=current_platform.dispatch_key,
10681072
)
10691073

1074+
direct_register_custom_op(
1075+
op_name="rocm_aiter_sparse_attn_indexer",
1076+
op_func=rocm_aiter_sparse_attn_indexer,
1077+
mutates_args=["topk_indices_buffer"],
1078+
fake_impl=rocm_aiter_sparse_attn_indexer_fake,
1079+
dispatch_key=current_platform.dispatch_key,
1080+
)
1081+
10701082
_OPS_REGISTERED = True
10711083

10721084
@staticmethod

0 commit comments

Comments
 (0)