Commit 9c4f484

Add APC support for HPU
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 976711d commit 9c4f484

File tree

5 files changed (+193, -93)

requirements/hpu.txt

Lines changed: 1 addition & 1 deletion

@@ -9,4 +9,4 @@ numpy==1.26.4
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@62ad004

vllm/attention/backends/hpu_attn.py

Lines changed: 69 additions & 47 deletions

@@ -9,9 +9,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
-                                      VLLMKVCache)
+from vllm_hpu_extension.flags import enabled_flags
+from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -57,16 +58,16 @@ def get_kv_cache_shape(
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)
 
 
 @dataclass
@@ -77,6 +78,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor]
 
 
 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -126,7 +128,15 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        ops.pa_impl = ops.pa
+        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+
+        self.prefill_impl = 'naive'
+        if "flex_attention" in enabled_flags():
+            self.prefill_impl = 'flex'
+        if "fsdpa" in enabled_flags():
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            self.prefill_impl = 'fsdpa'
 
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
@@ -138,27 +148,18 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                              '0').lower() in ['1', 'true']
-        self.fused_scaled_dot_product_attention = None
-        if self.prefill_usefusedsdpa:
+        if self.prefill_impl == 'fsdpa':
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
-            try:
-                from habana_frameworks.torch.hpex.kernels import FusedSDPA
-                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
-                    FusedSDPA)
-            except ImportError:
-                logger.warning("Could not import HPU FusedSDPA kernel. "
-                               "vLLM will use native implementation.")
 
         supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
 
-        if attn_type != AttentionType.DECODER:
+        self.attn_type = attn_type
+        if self.attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
@@ -192,15 +193,17 @@ def forward(
         batch_size, seq_len, hidden_size = query.shape
         _, seq_len_kv, _ = key.shape
 
-        query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
+        key_cache = None
+        value_cache = None
+        if attn_metadata.is_prompt and self.attn_type \
+            is not AttentionType.ENCODER_ONLY:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
+        if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
@@ -214,36 +217,32 @@ def forward(
 
         if attn_metadata.is_prompt:
             # Prompt run.
-            if not self.prefill_usefusedsdpa:
-                # TODO: move this outside of model
-                assert attn_metadata.attn_bias is not None, \
-                    'attn_bias must be set before calling model.forward!'
-                attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None:
-                    position_bias = _make_alibi_bias(self.alibi_slopes,
-                                                     self.num_kv_heads,
-                                                     attn_bias.dtype,
-                                                     attn_bias.shape[-1])
-                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
-                    attn_bias.add_(position_bias)
-            else:
-                attn_bias = None
-
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
+
+            attn_bias = attn_metadata.attn_bias
+            if attn_bias is not None and self.alibi_slopes is not None:
+                position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                 self.num_kv_heads,
+                                                 attn_bias.dtype,
+                                                 attn_bias.shape[-1])
+                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
+                attn_bias.add_(position_bias)
+
+            block_list = attn_metadata.block_list if attn_metadata \
+                and attn_metadata.block_list is not None else None
+
             out = ops.prompt_attention(
-                query.view(query_shape),
-                key.view(kv_shape),
-                value.view(kv_shape),
+                impl=self.prefill_impl,
+                query=query.view(query_shape),
+                key=key.view(kv_shape),
+                value=value.view(kv_shape),
+                is_causal=True,
                 attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                softmax_op=self.softmax,
-                matmul_av_op=self.matmul_av,
-                fsdpa_op=self.fused_scaled_dot_product_attention,
-            )
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
@@ -267,6 +266,29 @@ def forward(
         return output.view(batch_size, seq_len, hidden_size)
 
 
+    def common_attention_args(self,
+                              block_list=None,
+                              key_cache=None,
+                              value_cache=None):
+        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
+            if self.fused_scaled_dot_product_attention is not None else None
+
+        return {
+            'scale': self.scale,
+            'matmul_qk_op': self.matmul_qk,
+            'matmul_av_op': self.matmul_av,
+            'batch2block_matmul_op': self.batch2block_matmul,
+            'block2batch_matmul_op': self.block2batch_matmul,
+            'fsdpa_op': fsdpa_op,
+            'keys_fetch_func': self.k_cache.fetch_from_cache,
+            'values_fetch_func': self.v_cache.fetch_from_cache,
+            'softmax_op': self.softmax,
+            'block_list': block_list,
+            'key_cache': key_cache,
+            'value_cache': value_cache,
+        }
+
+
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,
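
For orientation: the prefill path no longer reads the VLLM_PROMPT_USE_FUSEDSDPA environment variable; it picks an implementation from the extension's feature flags and passes it to ops.prompt_attention via impl=. Below is a minimal sketch (not part of the commit) of that selection, using only the calls visible in the diff above; the helper name select_prefill_impl is illustrative.

import vllm_hpu_extension.kernels as kernels
from vllm_hpu_extension.flags import enabled_flags


def select_prefill_impl(alibi_slopes=None) -> str:
    # Mirrors the selection added to HPUAttentionImpl.__init__ above.
    impl = 'naive'
    if "flex_attention" in enabled_flags():
        impl = 'flex'
    if "fsdpa" in enabled_flags():
        # FusedSDPA prefill cannot be combined with ALiBi slopes.
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        impl = 'fsdpa'
    return impl


# The fused kernel handle itself now comes from the extension, as in the diff:
fsdpa_kernel = kernels.fsdpa()

The shared kwargs (scale, matmul/softmax ops, cache fetch functions, block_list, key_cache, value_cache) are collected once in common_attention_args so that both the prompt and decode paths can unpack them; giving the prefill path access to already-cached blocks appears to be the hook for the automatic prefix caching (APC) named in the commit title.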

vllm/attention/ops/hpu_paged_attn.py

Lines changed: 9 additions & 27 deletions

@@ -5,7 +5,7 @@
 ###############################################################################
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from vllm_hpu_extension import cache_ops, ops
@@ -64,43 +64,25 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
     def forward_decode(**kwargs) -> torch.Tensor:
         return ops.flat_pa(**kwargs)
 
-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        subquery_start_loc: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_query_len: int,
-        alibi_slopes: Optional[torch.Tensor],
-        sliding_window: Optional[int],
-    ) -> torch.Tensor:
-        raise NotImplementedError(
-            "forward_prefix is not implemented for HPUPagedAttention")
-
     @staticmethod
     def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)
 
     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dsts: torch.Tensor,
    ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
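
The block-mapping arguments change here from Python dicts to tensors. As an illustrative sketch (not from the commit), and assuming the (src, dst)-pair-per-row layout vLLM uses for block mappings on other backends, a dict such as {0: 4, 1: 5} would be expressed as:

import torch

# Hypothetical mapping: block 0 -> 4 and block 1 -> 5.
# Assumption: one (src, dst) pair per row, as on other vLLM backends.
src_to_dsts = torch.tensor([[0, 4], [1, 5]], dtype=torch.int64)

# HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
# HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)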

vllm/model_executor/layers/layernorm.py

Lines changed: 2 additions & 1 deletion

@@ -168,7 +168,8 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        from vllm_hpu_extension.ops import HPUFusedRMSNorm
+        from vllm_hpu_extension.kernels import rms_norm
+        HPUFusedRMSNorm = rms_norm()
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
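
The fused RMSNorm kernel is now looked up through vllm_hpu_extension.kernels.rms_norm() rather than imported directly from ops. A short sketch of the pattern, based only on what the diff shows (the lookup may return None, in which case the native path is taken):

from vllm_hpu_extension.kernels import rms_norm

HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None:
    # No fused kernel available: RMSNorm.forward_hpu falls back to
    # the native PyTorch implementation (self.forward_native).
    pass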
