 from typing import Any, Dict, List, Optional, Tuple, Type

 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
-                                      VLLMKVCache)
+from vllm_hpu_extension.flags import enabled_flags
+from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -57,16 +58,16 @@ def get_kv_cache_shape(
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


 @dataclass
@@ -77,6 +78,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor]


 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -126,7 +128,15 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        ops.pa_impl = ops.pa
+        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+
+        self.prefill_impl = 'naive'
+        if "flex_attention" in enabled_flags():
+            self.prefill_impl = 'flex'
+        if "fsdpa" in enabled_flags():
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            self.prefill_impl = 'fsdpa'

         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
@@ -138,27 +148,18 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                              '0').lower() in ['1', 'true']
-        self.fused_scaled_dot_product_attention = None
-        if self.prefill_usefusedsdpa:
+        if self.prefill_impl == 'fsdpa':
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
-            try:
-                from habana_frameworks.torch.hpex.kernels import FusedSDPA
-                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
-                    FusedSDPA)
-            except ImportError:
-                logger.warning("Could not import HPU FusedSDPA kernel. "
-                               "vLLM will use native implementation.")

         supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")

-        if attn_type != AttentionType.DECODER:
+        self.attn_type = attn_type
+        if self.attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
@@ -192,15 +193,17 @@ def forward(
         batch_size, seq_len, hidden_size = query.shape
         _, seq_len_kv, _ = key.shape

-        query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
+        key_cache = None
+        value_cache = None
+        if attn_metadata.is_prompt and self.attn_type \
+            is not AttentionType.ENCODER_ONLY:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
+        if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

@@ -214,36 +217,32 @@ def forward(

         if attn_metadata.is_prompt:
             # Prompt run.
-            if not self.prefill_usefusedsdpa:
-                # TODO: move this outside of model
-                assert attn_metadata.attn_bias is not None, \
-                    'attn_bias must be set before calling model.forward!'
-                attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None:
-                    position_bias = _make_alibi_bias(self.alibi_slopes,
-                                                     self.num_kv_heads,
-                                                     attn_bias.dtype,
-                                                     attn_bias.shape[-1])
-                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
-                    attn_bias.add_(position_bias)
-            else:
-                attn_bias = None
-
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
+
+            attn_bias = attn_metadata.attn_bias
+            if attn_bias is not None and self.alibi_slopes is not None:
+                position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                 self.num_kv_heads,
+                                                 attn_bias.dtype,
+                                                 attn_bias.shape[-1])
+                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
+                attn_bias.add_(position_bias)
+
+            block_list = attn_metadata.block_list if attn_metadata \
+                            and attn_metadata.block_list is not None else None
+
             out = ops.prompt_attention(
-                query.view(query_shape),
-                key.view(kv_shape),
-                value.view(kv_shape),
+                impl=self.prefill_impl,
+                query=query.view(query_shape),
+                key=key.view(kv_shape),
+                value=value.view(kv_shape),
+                is_causal=True,
                 attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                softmax_op=self.softmax,
-                matmul_av_op=self.matmul_av,
-                fsdpa_op=self.fused_scaled_dot_product_attention,
-            )
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
@@ -267,6 +266,29 @@ def forward(
         return output.view(batch_size, seq_len, hidden_size)


+    def common_attention_args(self,
+                              block_list=None,
+                              key_cache=None,
+                              value_cache=None):
+        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
+            if self.fused_scaled_dot_product_attention is not None else None
+
+        return {
+            'scale': self.scale,
+            'matmul_qk_op': self.matmul_qk,
+            'matmul_av_op': self.matmul_av,
+            'batch2block_matmul_op': self.batch2block_matmul,
+            'block2batch_matmul_op': self.block2batch_matmul,
+            'fsdpa_op': fsdpa_op,
+            'keys_fetch_func': self.k_cache.fetch_from_cache,
+            'values_fetch_func': self.v_cache.fetch_from_cache,
+            'softmax_op': self.softmax,
+            'block_list': block_list,
+            'key_cache': key_cache,
+            'value_cache': value_cache,
+        }
+
+
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,
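# A minimal standalone sketch of the flag-based prefill selection introduced
# in __init__ above, assuming only that enabled_flags() from
# vllm_hpu_extension.flags returns the set of active feature flags (as used in
# the diff). The helper name select_prefill_impl and its argument are
# illustrative, not part of the change itself.
from typing import Optional

import torch
from vllm_hpu_extension.flags import enabled_flags


def select_prefill_impl(alibi_slopes: Optional[torch.Tensor] = None) -> str:
    # Default to the naive prompt-attention path.
    impl = 'naive'
    # Use flex_attention when that flag is enabled.
    if "flex_attention" in enabled_flags():
        impl = 'flex'
    # FusedSDPA takes precedence, but is incompatible with ALiBi slopes.
    if "fsdpa" in enabled_flags():
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        impl = 'fsdpa'
    return impl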