 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
 
+if current_platform.is_rocm():
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
 
 def check_xformers_availability():
     global USE_XFORMERS_OPS
@@ -228,6 +230,9 @@ def forward(
         # shape does not match the query shape, so we optionally let the model
         # definition specify the output tensor shape.
         output_shape: Optional[torch.Size] = None,
+        positions: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: bool = False,
     ) -> torch.Tensor:
         """
         The KV cache is stored inside this class and is accessed via
@@ -245,9 +250,15 @@ def forward(
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
-            output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
-                                 device=query.device)
+            if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+                output = torch.empty(output_shape,
+                                     dtype=query.dtype,
+                                     device=query.device)
+            else:
+                output = torch.zeros(output_shape,
+                                     dtype=query.dtype,
+                                     device=query.device)
+
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
@@ -269,15 +280,19 @@ def forward(
                     attn_metadata = attn_metadata[self.layer_name]
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 self.impl.forward(self,
-                                  query,
-                                  key,
-                                  value,
-                                  self_kv_cache,
-                                  attn_metadata,
-                                  output=output)
+                                  query,
+                                  key,
+                                  value,
+                                  self_kv_cache,
+                                  attn_metadata,
+                                  output=output)
             else:
-                torch.ops.vllm.unified_attention_with_output(
-                    query, key, value, output, self.layer_name)
+                if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+                    torch.ops.vllm.unified_attention_with_output(
+                        query, key, value, output, self.layer_name, None, positions, cos_sin_cache, True)
+                else:
+                    torch.ops.vllm.unified_attention_with_output(
+                        query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
@@ -485,6 +500,9 @@ def unified_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
     output_scale: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+    cos_sin_cache: Optional[torch.Tensor] = None,
+    is_neox: bool = False,
 ) -> None:
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
@@ -493,14 +511,29 @@ def unified_attention_with_output(
         attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    self.impl.forward(self,
-                      query,
-                      key,
-                      value,
-                      kv_cache,
-                      attn_metadata,
-                      output=output,
-                      output_scale=output_scale)
+
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+        from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
+        assert isinstance(self.impl, TritonAttentionImpl), f"Expected the TritonAttentionImpl backend when VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=1, but got {self.impl=}"
+        assert self.impl.kv_sharing_target_layer_name is None, "kv_sharing_target_layer_name is not supported with the fused RoPE path"
+        self.impl.forward(self,
+                          query,
+                          key,
+                          value,
+                          kv_cache,
+                          attn_metadata,
+                          output=output,
+                          output_scale=output_scale,
+                          positions=positions, cos_sin_cache=cos_sin_cache, is_neox=is_neox)
+    else:
+        self.impl.forward(self,
+                          query,
+                          key,
+                          value,
+                          kv_cache,
+                          attn_metadata,
+                          output=output,
+                          output_scale=output_scale)
 
     maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
@@ -512,6 +545,9 @@ def unified_attention_with_output_fake(
     output: torch.Tensor,
     layer_name: str,
     output_scale: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+    cos_sin_cache: Optional[torch.Tensor] = None,
+    is_neox: bool = False,
 ) -> None:
     return
 
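
For reference, a minimal sketch of how a model-side attention block might pass the new arguments when the flag is enabled. The module names used here (qkv_proj, rotary_emb, attn, o_proj) follow common vLLM model code but are illustrative assumptions only; this commit touches only the attention layer and the custom op, not any model definition.

# Hypothetical caller (assumed names). With the flag enabled, RoPE is applied
# inside the fused Triton attention path, so the layer hands positions and the
# rotary cos/sin cache to Attention.forward instead of rotating q/k itself.
def _attention_block(self, positions: torch.Tensor,
                     hidden_states: torch.Tensor) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
        # These are forwarded through Attention.forward to
        # torch.ops.vllm.unified_attention_with_output(...).
        out = self.attn(q, k, v,
                        positions=positions,
                        cos_sin_cache=self.rotary_emb.cos_sin_cache,
                        is_neox=True)
    else:
        # Default path: apply rotary embeddings before attention.
        q, k = self.rotary_emb(positions, q, k)
        out = self.attn(q, k, v)
    output, _ = self.o_proj(out)
    return output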