
Commit 25f843d

Merge branch '355_wip' into ck_tile_gemm
2 parents 5800181 + 48dc133 commit 25f843d

6 files changed: +130 −52 lines changed

vllm/attention/layer.py

Lines changed: 55 additions & 19 deletions
@@ -28,6 +28,8 @@
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None

+if current_platform.is_rocm():
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

 def check_xformers_availability():
     global USE_XFORMERS_OPS
@@ -228,6 +230,9 @@ def forward(
         # shape does not match the query shape, so we optionally let the model
         # definition specify the output tensor shape.
         output_shape: Optional[torch.Size] = None,
+        positions: torch.Tensor = None,
+        cos_sin_cache: torch.Tensor = None,
+        is_neox: bool = False,
     ) -> torch.Tensor:
         """
         The KV cache is stored inside this class and is accessed via
@@ -245,9 +250,15 @@ def forward(
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
-            output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
-                                 device=query.device)
+            if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+                output = torch.empty(output_shape,
+                                     dtype=query.dtype,
+                                     device=query.device)
+            else:
+                output = torch.zeros(output_shape,
+                                     dtype=query.dtype,
+                                     device=query.device)
+
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
@@ -269,15 +280,19 @@ def forward(
                     attn_metadata = attn_metadata[self.layer_name]
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(self,
-                                  query,
-                                  key,
-                                  value,
-                                  self_kv_cache,
-                                  attn_metadata,
-                                  output=output)
+                                  query,
+                                  key,
+                                  value,
+                                  self_kv_cache,
+                                  attn_metadata,
+                                  output=output)
             else:
-                torch.ops.vllm.unified_attention_with_output(
-                    query, key, value, output, self.layer_name)
+                if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+                    torch.ops.vllm.unified_attention_with_output(
+                        query, key, value, output, self.layer_name, None, positions, cos_sin_cache, True)
+                else:
+                    torch.ops.vllm.unified_attention_with_output(
+                        query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
@@ -485,6 +500,9 @@ def unified_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
     output_scale: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+    cos_sin_cache: Optional[torch.Tensor] = None,
+    is_neox: bool = False,
 ) -> None:
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
@@ -493,14 +511,29 @@ def unified_attention_with_output(
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    self.impl.forward(self,
-                      query,
-                      key,
-                      value,
-                      kv_cache,
-                      attn_metadata,
-                      output=output,
-                      output_scale=output_scale)
+
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+        from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
+        assert isinstance(self.impl, TritonAttentionImpl), f"Expect attention implementation = TritonAttentionImpl for VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=1 but got {self.impl=}"
+        assert self.impl.kv_sharing_target_layer_name is None, "kv_sharing_target_layer_name error"
+        self.impl.forward(self,
+                          query,
+                          key,
+                          value,
+                          kv_cache,
+                          attn_metadata,
+                          output=output,
+                          output_scale=output_scale,
+                          positions=positions, cos_sin_cache=cos_sin_cache, is_neox=is_neox)
+    else:
+        self.impl.forward(self,
+                          query,
+                          key,
+                          value,
+                          kv_cache,
+                          attn_metadata,
+                          output=output,
+                          output_scale=output_scale)

     maybe_save_kv_layer_to_connector(layer_name, kv_cache)

@@ -512,6 +545,9 @@ def unified_attention_with_output_fake(
     output: torch.Tensor,
     layer_name: str,
     output_scale: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+    cos_sin_cache: Optional[torch.Tensor] = None,
+    is_neox: bool = False,
 ) -> None:
     return
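
A reading aid for the hunks above (not part of the commit): when the flag is set, the caller-side output buffer is allocated with torch.empty rather than torch.zeros because the fused Triton kernel in triton_attn.py is asked to zero-fill that buffer itself (output_zeros=True, zeros_out=output), and the extra positional arguments in the custom-op call map onto the parameters added to unified_attention_with_output. A minimal sketch of that mapping, assuming only the signature shown in this diff (the helper name is illustrative):

import torch

def call_fused_unified_attention(query, key, value, output, layer_name,
                                 positions, cos_sin_cache):
    # Mirrors the call emitted when VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    # is enabled: output_scale is passed as None and is_neox is hard-coded to True.
    torch.ops.vllm.unified_attention_with_output(
        query, key, value, output, layer_name,
        None,           # output_scale
        positions,      # token positions consumed by the fused RoPE kernel
        cos_sin_cache,  # rotary cos/sin table; split into cos and sin in the backend
        True,           # is_neox
    )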

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -166,6 +166,7 @@
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = False


 def get_default_cache_root():
@@ -1176,6 +1177,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_TUNED_CONFIG_FOLDER":
     lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),

+    # Use AITER Triton fused rope + zeros + reshape_and_cache
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
+    lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "0"))),
+
 }

 # --8<-- [end:env-vars-definition]
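
A minimal usage sketch for the new flag, assuming a ROCm build with AITER installed: the lambda above parses the value with bool(int(...)), so it expects "0"/"1"-style strings, and the gating added in layer.py, llama.py and triton_attn.py additionally requires VLLM_ROCM_USE_AITER to be on.

import os

# Both flags must be set before the vLLM modules above are imported,
# since the module-level gate is evaluated at import time.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE"] = "1"

# Equivalent of the lambda registered above.
enabled = bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "0")))
print(enabled)  # True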

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 2 deletions
@@ -203,8 +203,7 @@ def __init__(self, quant_config: Fp8Config):
         # and at the moment are MI300 series
         self.use_aiter_and_is_supported = (current_platform.is_rocm()
                                            and envs.VLLM_ROCM_USE_AITER
-                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
-                                           and current_platform.is_fp8_fnuz())
+                                           and envs.VLLM_ROCM_USE_AITER_LINEAR)
         self.use_ck_tile_and_is_supported = (
             current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
             and envs.VLLM_ROCM_USE_AITER_CK_TILE_LINEAR

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 6 additions & 3 deletions
@@ -54,9 +54,12 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
+    # import aiter as rocm_aiter
+
+    # return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


 def rocm_aiter_gemm_w8a8_blockscale_fake(
@@ -238,7 +241,7 @@ def apply_w8a8_block_fp8_linear(
                                                  block_size, input.dtype)

     else:
-        if use_aiter_and_is_supported or use_ck_tile_and_is_supported:
+        if use_aiter_and_is_supported and current_platform.is_fp8_fnuz():
             q_input, x_scale = aiter_per1x128_quant(
                 input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
         else:
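
For context on the swap above, a hedged usage sketch of the Triton kernel now called by rocm_aiter_gemm_w8a8_blockscale_impl. The call signature comes straight from this diff; the tensor shapes, dtypes and scale layouts below are assumptions based on the per-1x128 activation / 128x128 weight block-scaling convention used elsewhere in this file (aiter_per1x128_quant, block_size), not something the commit specifies.

import torch
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

M, N, K = 32, 4096, 4096
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)           # quantized activations (assumed dtype)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fnuz)           # quantized weights (assumed dtype)
As = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)         # per-1x128 activation scales (assumed layout)
Bs = torch.ones(N // 128, K // 128, device="cuda", dtype=torch.float32)  # per-128x128 weight scales (assumed layout)

out = gemm_a8w8_blockscale(A, B, As, Bs, dtype=torch.bfloat16)           # [M, N] output in bfloat16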

vllm/model_executor/models/llama.py

Lines changed: 11 additions & 2 deletions
@@ -56,6 +56,9 @@
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

+from vllm.platforms import current_platform
+if current_platform.is_rocm():
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

 class LlamaMLP(nn.Module):

@@ -197,8 +200,14 @@ def forward(
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v)
+
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+            attn_output = self.attn(q, k, v,
+                positions=positions, cos_sin_cache=self.rotary_emb.cos_sin_cache, is_neox=self.rotary_emb.is_neox_style)
+        else:
+            q, k = self.rotary_emb(positions, q, k)
+            attn_output = self.attn(q, k, v)
+
         output, _ = self.o_proj(attn_output)
         return output
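
The layer now hands self.rotary_emb.cos_sin_cache directly to the attention layer instead of applying rotary embeddings itself. A small sketch of the layout assumption behind that, based on vLLM's RotaryEmbedding storing cos and sin concatenated along the last dimension, which is why the Triton backend can recover them with a single chunk() call:

import torch

max_positions, rotary_dim = 4096, 128
# Assumed layout: first half of the last dim is cos, second half is sin.
cos_sin_cache = torch.randn(max_positions, rotary_dim)
cos, sin = cos_sin_cache.chunk(2, dim=-1)  # mirrors what triton_attn.py does before the fused kernel
assert cos.shape == sin.shape == (max_positions, rotary_dim // 2)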

vllm/v1/attention/backends/triton_attn.py

Lines changed: 52 additions & 26 deletions
@@ -25,6 +25,10 @@

 logger = init_logger(__name__)

+if current_platform.is_rocm():
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+        from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache

 @dataclass
 class TritonAttentionMetadata:
@@ -288,6 +292,9 @@ def forward(
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        positions: torch.Tensor = None,
+        cos_sin_cache: torch.Tensor = None,
+        is_neox: bool = False,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

@@ -325,32 +332,51 @@ def forward(
                 kv_cache, self.num_kv_heads, self.head_size)
         else:
             key_cache, value_cache = kv_cache.unbind(0)
-
-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            if use_prefill_decode_attn:
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+            assert self.kv_sharing_target_layer_name is None, "self.kv_sharing_target_layer_name error"
+            cos, sin = cos_sin_cache.chunk(2, dim = -1)
+            is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
+            if is_fp8_kv_cache:
+                key_cache_og_dtype = key_cache.dtype
+                value_cache_og_dtype = value_cache.dtype
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            query, key, key_cache, value_cache, output = fused_qk_rope_reshape_and_cache(
+                query, key, value, key_cache, value_cache, attn_metadata.slot_mapping,
+                positions, cos, sin,
+                layer._k_scale, layer._v_scale,
+                is_neox,
+                flash_layout=(not use_prefill_decode_attn), apply_scale=is_fp8_kv_cache, offs=None, q_out=query, k_out=key, output_zeros=True, zeros_out=output)
+            if is_fp8_kv_cache:
+                key_cache = key_cache.view(key_cache_og_dtype)
+                value_cache = value_cache.view(value_cache_og_dtype)
+        else:
+            if self.kv_sharing_target_layer_name is None:
+                # Reshape the input keys and values and store them in the cache.
+                # Skip this if sharing KV cache with an earlier attention layer.
+                if use_prefill_decode_attn:
+                    PagedAttention.write_to_paged_cache(
+                        key,
+                        value,
+                        key_cache,
+                        value_cache,
+                        attn_metadata.slot_mapping,
+                        self.kv_cache_dtype,
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
+                else:
+                    torch.ops._C_cache_ops.reshape_and_cache_flash(
+                        key,
+                        value,
+                        key_cache,
+                        value_cache,
+                        attn_metadata.slot_mapping,
+                        self.kv_cache_dtype,
+                        layer._k_scale,
+                        layer._v_scale,
+                    )

         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)