@@ -550,8 +550,10 @@ def forward(
         attn_metadata: ROCmFlashAttentionMetadata,
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
+        q_scale: Optional[torch.Tensor] = None,
+        prob_scale: Optional[torch.Tensor] = None,
+        fp8_out_scale: Optional[torch.Tensor] = None,
         output: Optional[torch.Tensor] = None,
-        fp8_comp_scales: List[Optional[torch.Tensor]] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

@@ -601,9 +603,6 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or [None, None,
-                                                                 None]
-
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -687,7 +686,7 @@ def forward(
             1.0 / q_scale.item(), 1.0 / k_scale.item(),
             1.0 / v_scale.item(), 1.0 / prob_scale.item(),
             fp8_out_scale.item()) if (
-                fp8_out_scale
+                fp8_out_scale and q_scale and prob_scale
                 and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
         out, _ = self.attn_func(
             query,
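
For reference, a minimal sketch of the caller-side change this diff implies. The three FP8 scales were previously packed into a single `fp8_comp_scales` list; they are now explicit Optional keyword arguments, and the widened guard in the last hunk checks that all three are present before calling `.item()` on them. The `attn` object and the leading positional arguments here are assumptions for illustration; only the scale parameter names come from the diff above:

    # Hypothetical call site before this change (packed list):
    # out = attn.forward(query, key, value, kv_cache, attn_metadata,
    #                    k_scale, v_scale,
    #                    fp8_comp_scales=[q_scale, prob_scale, fp8_out_scale])

    # After this change, each scale is passed (or omitted) individually:
    out = attn.forward(query, key, value, kv_cache, attn_metadata,
                       k_scale, v_scale,
                       q_scale=q_scale,
                       prob_scale=prob_scale,
                       fp8_out_scale=fp8_out_scale)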