@@ -3238,6 +3238,26 @@ def __call__(

         return hidden_states

+def xla_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias = attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value

 class AttnProcessor2_0:
     r"""
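Note: the helper added above mirrors the eager "reference implementation" of scaled dot-product attention from the PyTorch docs; it relies on the file's existing `math` and `torch` imports. A minimal CPU sanity check one might run (a sketch, not part of this commit; it assumes `xla_scaled_dot_product_attention` from the hunk above is in scope):

```python
# Sketch: compare the manual fallback against the fused kernel on CPU.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim)
query = torch.randn(2, 8, 64, 32)
key = torch.randn(2, 8, 64, 32)
value = torch.randn(2, 8, 64, 32)

expected = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
actual = xla_scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
# Expected to print True up to small float32 numerical differences.
print(torch.allclose(expected, actual, atol=1e-5))
```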
@@ -3310,7 +3330,7 @@ def __call__(

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
+        hidden_states = xla_scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

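This hunk makes AttnProcessor2_0 always take the manual attention path. A possible refinement (a sketch, not in this commit; presumably the fused kernel is only a problem under XLA) is to dispatch on the device type of the local tensors at this point in `__call__`:

```python
# Hypothetical dispatch; query/key/value/attention_mask are the processor's local variables.
import torch.nn.functional as F

if query.device.type == "xla":
    # Manual eager implementation for XLA devices.
    hidden_states = xla_scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
else:
    # Keep the fused PyTorch 2.0 kernel everywhere else.
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
```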
@@ -3408,6 +3428,7 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            logger.warning("Using flash attention")
             if attention_mask is not None:
                 attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
                 # Convert mask to float and replace 0s with -inf and 1s with 0
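The comment at the end of this hunk describes turning a 0/1 attention mask into an additive float bias before the pallas flash attention call. A small standalone illustration of that conversion (a sketch with made-up values, not taken verbatim from the file):

```python
# Toy 0/1 -> additive-bias mask conversion: 1 = keep (bias 0.0), 0 = drop (bias -inf).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # (batch, key_len)
batch_size = attention_mask.shape[0]
attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])

bias = torch.zeros_like(attention_mask, dtype=torch.float32)
bias = bias.masked_fill(attention_mask == 0, float("-inf"))
print(bias)  # tensor([[[[0., 0., 0., -inf]]]])
```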
@@ -3426,7 +3447,7 @@ def __call__(
             logger.warning(
                 "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
             )
-            hidden_states = F.scaled_dot_product_attention(
+            hidden_states = xla_scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )

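The fallback above is reached through the `all(tensor.shape[2] >= 4096 ...)` gate shown earlier in this `__call__`. A tiny illustration of that gate (shapes are made up for the example):

```python
# Toy sequence-length gate: the pallas flash attention path requires seq_len >= 4096
# on every tensor; otherwise the manual fallback runs.
import torch

query = torch.randn(1, 8, 2048, 64)   # (batch, heads, seq_len, head_dim), short query
key = torch.randn(1, 8, 4096, 64)
value = torch.randn(1, 8, 4096, 64)

use_pallas_flash_attention = all(t.shape[2] >= 4096 for t in [query, key, value])
print(use_pallas_flash_attention)  # False -> xla_scaled_dot_product_attention is used instead
```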