
Commit ba9d00d

Merge remote-tracking branch 'origin/fix_env' into 0902_rc1

2 parents 0aed031 + 6186904
3 files changed: +131, -75 lines


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 2 additions & 1 deletion

@@ -828,7 +828,8 @@ def forward(
             use_custom = use_rocm_custom_paged_attention(
                 decode_query.dtype, head_size, block_size, gqa_ratio,
                 decode_meta.max_decode_seq_len, self.sliding_window,
-                self.kv_cache_dtype, self.alibi_slopes)
+                self.kv_cache_dtype,
+                self.alibi_slopes) and not is_rocm_aiter_paged_attn_enabled()

            if use_custom:
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
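The net effect of this hunk is that the custom ROCm paged-attention kernel is only chosen when the existing heuristic passes and the AITER paged-attention path is not enabled. A minimal sketch of the composed condition; the body of is_rocm_aiter_paged_attn_enabled() is not shown in this diff, so the version below is an assumed reconstruction from the two environment flags removed in rocm.py in the next file:

import os

def is_rocm_aiter_paged_attn_enabled() -> bool:
    # Assumed reconstruction, for illustration only: the two flags that the
    # rocm.py hunk below removes from use_rocm_custom_paged_attention().
    return (os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get("VLLM_ROCM_USE_AITER_PAGED_ATTN", "0") == "1")

# Stand-in for the heuristic result of use_rocm_custom_paged_attention(...).
heuristic_says_yes = True
use_custom = heuristic_says_yes and not is_rocm_aiter_paged_attn_enabled()
print(use_custom)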

vllm/platforms/rocm.py

Lines changed: 1 addition & 3 deletions

@@ -144,9 +144,7 @@ def use_rocm_custom_paged_attention(
                 and (block_size == 16 or block_size == 32)
                 and (gqa_ratio >= 1 and gqa_ratio <= 16)
                 and max_seq_len <= 128 * 1024
-                and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
-                and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)
+                and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and sinks is None)

    else:
        return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
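In other words, the AITER exclusion is not dropped; it moves from use_rocm_custom_paged_attention() to its caller in rocm_flash_attn.py. A self-contained sanity check with plain booleans (no vLLM imports) that the overall decision is unchanged, assuming the call-site helper is_rocm_aiter_paged_attn_enabled() reduces to VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_PAGED_ATTN, which this diff does not show:

from itertools import product

def old_gate(heuristic, custom_pa, aiter, aiter_pa, has_sinks):
    # Pre-commit: the AITER exclusion lived inside the helper.
    return (heuristic and custom_pa
            and not (aiter_pa and aiter) and not has_sinks)

def new_gate(heuristic, custom_pa, aiter, aiter_pa, has_sinks):
    # Post-commit: the helper only checks its own flag; the caller
    # applies "and not is_rocm_aiter_paged_attn_enabled()".
    helper = heuristic and custom_pa and not has_sinks
    return helper and not (aiter_pa and aiter)

for combo in product([False, True], repeat=5):
    assert old_gate(*combo) == new_gate(*combo)
print("decision unchanged for all flag combinations")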

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 128 additions & 71 deletions

@@ -89,23 +89,34 @@ def _vllm_layout_trans_kernel(
     tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
     tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)

-def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                      k_cache, v_cache, max_seq_len, k_scale, v_scale,
-                      output_dtype, total_tokens):
+@torch.inference_mode()
+def vllm_layout_trans(b_query_lens_loc,
+                      b_seq_lens_loc,
+                      block_table,
+                      k_cache,
+                      v_cache,
+                      max_seq_len,
+                      k_scale,
+                      v_scale,
+                      output_dtype,
+                      total_tokens,
+                      k_values=None,
+                      v_values=None):
     H_KV = v_cache.shape[2]
     D = v_cache.shape[3]
     BLOCK_SIZE = v_cache.shape[1]
-
-    k_values = torch.empty(
-        (total_tokens, H_KV, D),
-        dtype=output_dtype,
-        device=k_cache.device,
-    )
-    v_values = torch.empty(
-        (total_tokens, H_KV, D),
-        dtype=output_dtype,
-        device=v_cache.device,
-    )
+    if k_values is None:
+        k_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=k_cache.device,
+        )
+    if v_values is None:
+        v_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=v_cache.device,
+        )

     grid = (block_table.shape[0],
             (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
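vllm_layout_trans() now accepts optional pre-allocated k_values/v_values output buffers, so the paged-KV-cache-to-contiguous transpose can write into buffers owned by the attention metadata (the k_buffer/v_buffer introduced further down) instead of allocating fresh tensors on every call for every layer. A minimal allocate-or-reuse sketch of the same pattern; the shapes and dtype below are made up so the snippet runs anywhere:

import torch

def get_or_alloc(buf, shape, dtype, device):
    # Reuse a caller-provided buffer when given, otherwise allocate fresh.
    if buf is None:
        return torch.empty(shape, dtype=dtype, device=device)
    return buf

# Hypothetical sizes, CPU tensors for portability.
total_tokens, h_kv, d = 64, 8, 128
fresh = get_or_alloc(None, (total_tokens, h_kv, d), torch.float32, "cpu")
persistent = torch.empty((total_tokens, h_kv, d), dtype=torch.float32)
reused = get_or_alloc(persistent, (total_tokens, h_kv, d), torch.float32, "cpu")
assert reused is persistent and fresh.shape == persistent.shape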
@@ -148,13 +159,14 @@ def flash_attn_varlen_func_impl(
     block_table: torch.Tensor,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
-    total_tokens: int = 0,
+    total_tokens: int,
+    k_values: Optional[torch.Tensor] = None,
+    v_values: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if total_tokens == 0:
-        total_tokens = int(cu_seqlens_k[-1].item())
     k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
                              k_cache, v_cache, max_seqlen_k, k_scale,
-                             v_scale, q.dtype, total_tokens)
+                             v_scale, q.dtype, total_tokens, k_values,
+                             v_values)

     output = aiter.flash_attn_varlen_func(
         q=q,
@@ -222,24 +234,27 @@ class AiterFlashAttentionMetadata:
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
-    cu_seq_lens: Optional[torch.Tensor]

     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
-    total_tokens: int
+    k_buffer: torch.Tensor
+    v_buffer: torch.Tensor
+    workspace_buffer: torch.Tensor
+    cu_seq_lens: torch.Tensor


 class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
-    cudagraph_support = AttentionCGSupport.ALWAYS
+    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
         self.device = device

         self.num_heads_q = self.model_config.get_num_attention_heads(
@@ -249,53 +264,68 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        # Sliding window size to be used with the AOT scheduler will be
-        # populated on first build() call.
-        self.aot_sliding_window: Optional[tuple[int, int]] = None
-        self.total_tokens: int = 0
-
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        self.total_tokens = self.model_config.max_model_len \
-            * self.vllm_config.scheduler_config.max_num_partial_prefills
-        res = self.build(common_prefix_len=0,
-                         common_attn_metadata=common_attn_metadata)
-        self.total_tokens = 0
-        return res

     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
-
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.max_seq_len
+
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
-        if max_query_len > 1:
-            # We pre-compute cumulative seq len needed for prefill attention
-            # here to avoid recomputing it for every layer
-            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=seq_lens.device)
-            torch.cumsum(seq_lens,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
-        else:
-            cu_seq_lens = None
-            num_actual_kv_tokens = 0
-
-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
+        num_seqs = common_attn_metadata.num_reqs
+        max_seq_len = common_attn_metadata.max_seq_len
+        num_actual_kv_tokens = int(seq_lens.sum())

         use_cascade = common_prefix_len > 0

+        nbytes_per_qo_elem = torch.finfo(self.model_config.dtype).bits // 8
+        max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
+                              1) // _PARTITION_SIZE_ROCM
+        empty_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        k_buffer = None
+        v_buffer = None
+        workspace_buffer = None
+        cu_seq_lens = None
+        if max_query_len > 1:
+            required_memory = num_actual_kv_tokens * \
+                self.num_heads_kv * self.headdim * 2 * 2
+            if required_memory >= empty_gpu_memory:
+                raise ValueError(
+                    f"Not enough GPU memory to allocate k_buffer and v_buffer. "
+                    f"Required: {required_memory} bytes, "
+                    f"Available: {empty_gpu_memory} bytes, please reduce the "
+                    f"max_num_seqs or max_model_len.")
+            if not torch.cuda.graphs.is_current_stream_capturing():
+                k_buffer = torch.empty(
+                    (num_actual_kv_tokens, self.num_heads_kv, self.headdim),
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                )
+                v_buffer = torch.empty(
+                    (num_actual_kv_tokens, self.num_heads_kv, self.headdim),
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                )
+                cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+                torch.cumsum(seq_lens,
+                             dim=0,
+                             dtype=cu_seq_lens.dtype,
+                             out=cu_seq_lens[1:])
+
+        workspace_buffer = torch.empty(
+            (num_seqs * self.num_heads_q * max_num_partitions * self.headdim) *
+            nbytes_per_qo_elem + 2 *
+            (num_seqs * self.num_heads_q * max_num_partitions) * 4,
+            dtype=torch.uint8,
+            device=self.device,
+        )
+
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             num_actual_kv_tokens=num_actual_kv_tokens,
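The builder now does the sizing up front: it checks that the two contiguous K/V transpose buffers fit in free GPU memory, and it allocates the paged-attention workspace once per build() instead of once per forward() call. A small sketch of the same arithmetic, runnable without a GPU; the interpretation of the 2 * 2 factor (2 bytes per fp16/bf16 element times two buffers, K and V) and the _PARTITION_SIZE_ROCM value of 256 are assumptions for illustration:

import torch

_PARTITION_SIZE_ROCM = 256  # assumed value, for illustration only

def kv_buffer_bytes(num_kv_tokens: int, num_heads_kv: int, headdim: int) -> int:
    # Two (tokens, kv_heads, headdim) buffers at 2 bytes per element.
    return num_kv_tokens * num_heads_kv * headdim * 2 * 2

def workspace_bytes(num_seqs: int, num_heads_q: int, headdim: int,
                    max_seq_len: int, dtype: torch.dtype = torch.bfloat16) -> int:
    nbytes_per_qo_elem = torch.finfo(dtype).bits // 8
    max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM
    # Partial outputs per (seq, head, partition) plus two float32-sized side arrays,
    # mirroring the formula in the hunk above.
    return ((num_seqs * num_heads_q * max_num_partitions * headdim) * nbytes_per_qo_elem
            + 2 * (num_seqs * num_heads_q * max_num_partitions) * 4)

# Example: 8192 KV tokens, 8 KV heads, 32 query heads, head dim 128, 4 sequences.
print(kv_buffer_bytes(8192, 8, 128), workspace_bytes(4, 32, 128, 8192))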
@@ -305,10 +335,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
-            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
-            total_tokens=self.total_tokens,
+            k_buffer=k_buffer,
+            v_buffer=v_buffer,
+            workspace_buffer=workspace_buffer,
+            cu_seq_lens=cu_seq_lens,
         )
         return attn_metadata

@@ -381,6 +413,7 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +443,9 @@ def __init__(
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
+        self.sinks = sinks
+        if self.sinks is not None:
+            raise NotImplementedError("Sinks are not supported for ROCM AITER")

     def forward(
         self,
@@ -491,6 +527,17 @@ def forward(
             block_table = attn_metadata.block_table

             if max_seqlen_q > 1:
+                if attn_metadata.cu_seq_lens is None:
+                    cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
+                                              dtype=torch.int32,
+                                              device=query.device)
+                    torch.cumsum(seqused_k,
+                                 dim=0,
+                                 dtype=cu_seq_lens.dtype,
+                                 out=cu_seq_lens[1:])
+                else:
+                    cu_seq_lens = attn_metadata.cu_seq_lens
+
                 torch.ops.vllm.flash_attn_varlen_func(
                     query[:num_actual_tokens],
                     key_cache,
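The prefill path now prefers the cu_seq_lens tensor pre-computed in build() and only rebuilds it from seqused_k when the metadata does not carry one (e.g. when buffer allocation was skipped during stream capture). The layout is the usual cumulative-lengths vector with a leading zero; a tiny runnable illustration of the same cumsum-into-a-slice idiom:

import torch

def make_cu_seq_lens(seq_lens: torch.Tensor) -> torch.Tensor:
    # [s0, s1, s2] -> [0, s0, s0+s1, s0+s1+s2], int32, same device as the input.
    cu = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device)
    torch.cumsum(seq_lens, dim=0, dtype=cu.dtype, out=cu[1:])
    return cu

print(make_cu_seq_lens(torch.tensor([3, 5, 2])))  # tensor([ 0,  3,  8, 10], dtype=torch.int32)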
@@ -503,25 +550,29 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=attn_metadata.cu_seq_lens,
+                    cu_seqlens_k=cu_seq_lens,
                     k_scale=layer._k_scale,
                     v_scale=layer._v_scale,
                     total_tokens=attn_metadata.num_actual_kv_tokens,
+                    k_values=attn_metadata.k_buffer,
+                    v_values=attn_metadata.v_buffer,
                 )
-
-            _, num_heads, head_size = query.shape
-            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
-            num_seqs = seqused_k.shape[0]
-            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
-                                  1) // _PARTITION_SIZE_ROCM
-
-            workspace_buffer = torch.empty(
-                (num_seqs * num_heads * max_num_partitions * head_size) *
-                nbytes_per_qo_elem + 2 *
-                (num_seqs * num_heads * max_num_partitions) * 4,
-                dtype=torch.uint8,
-                device=output.device,
-            )
+            if attn_metadata.workspace_buffer is None:
+                _, num_heads, head_size = query.shape
+                nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
+                num_seqs = seqused_k.shape[0]
+                max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
+                                      1) // _PARTITION_SIZE_ROCM
+
+                workspace_buffer = torch.empty(
+                    (num_seqs * num_heads * max_num_partitions * head_size) *
+                    nbytes_per_qo_elem + 2 *
+                    (num_seqs * num_heads * max_num_partitions) * 4,
+                    dtype=torch.uint8,
+                    device=output.device,
+                )
+            else:
+                workspace_buffer = attn_metadata.workspace_buffer

             torch.ops.aiter.paged_attention_v1(
                 output[:num_actual_tokens],
@@ -543,6 +594,12 @@ def forward(
                 None,
                 _PARTITION_SIZE_ROCM,
             )
+            if workspace_buffer is not None:
+                workspace_buffer.zero_()
+            if attn_metadata.k_buffer is not None:
+                attn_metadata.k_buffer.zero_()
+            if attn_metadata.v_buffer is not None:
+                attn_metadata.v_buffer.zero_()
             return output
         else:
             raise NotImplementedError(
