1 file changed: +6 −2 lines changed
@@ -576,6 +576,9 @@ def _flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
     window_size: Tuple[int, int] = (-1, -1),
     softcap: float = 0.0,
     alibi_slopes: Optional[torch.Tensor] = None,
@@ -586,6 +589,9 @@ def _flash_attention(
         q=query,
         k=key,
         v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
         window_size=window_size,
         softcap=softcap,
         alibi_slopes=alibi_slopes,
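For readability, here is a minimal sketch of how `_flash_attention` reads after these two hunks are applied. The signature and keyword mapping come directly from the diff; the `flash_attn` import and the surrounding module context are assumptions, since the diff does not show them.

```python
from typing import Optional, Tuple

import torch
from flash_attn import flash_attn_func  # assumed import; not shown in the diff


def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # The wrapper's `scale` / `is_causal` names map onto flash-attn's
    # `softmax_scale` / `causal` keyword arguments.
    return flash_attn_func(
        q=query,
        k=key,
        v=value,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
    )
```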
@@ -748,7 +754,6 @@ def _flash_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    dropout_p: float = 0.0,
     scale: Optional[float] = None,
     is_causal: bool = False,
     window_size: Tuple[int, int] = (-1, -1),
@@ -761,7 +766,6 @@ def _flash_attention_hub(
         q=query,
         k=key,
         v=value,
-        dropout_p=dropout_p,
         softmax_scale=scale,
         causal=is_causal,
         window_size=window_size,
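To make the changed surface concrete, a hypothetical call site is sketched below: after this change, `dropout_p` is configured on `_flash_attention`, while `_flash_attention_hub` keeps only `scale` and `is_causal`. The tensor shapes and dtype are illustrative and follow flash-attn's (batch, seqlen, num_heads, head_dim) layout; none of these values appear in the diff.

```python
import torch

# Illustrative inputs only; flash-attn expects fp16/bf16 CUDA tensors
# laid out as (batch, seqlen, num_heads, head_dim).
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

# Dropout now lives on _flash_attention (scale = 1/sqrt(head_dim) = 0.125).
out = _flash_attention(q, k, v, dropout_p=0.1, scale=0.125, is_causal=True)

# _flash_attention_hub no longer accepts dropout_p.
out_hub = _flash_attention_hub(q, k, v, scale=0.125, is_causal=True)
```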