 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -106,17 +108,22 @@ def _apply_rope(self, x, start_index):
         )
         return x
 
-    def _can_use_flash_attention(self):
-        if not has_flash_attention_support():
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-        if self.logit_soft_cap is None:
-            return True
-        sig = inspect.signature(ops.dot_product_attention)
-        # We can currently only run soft capped attention for keras >= 3.10
-        # and only on TPU.
-        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
 
     def _compute_attention(
         self,
@@ -140,7 +147,7 @@ def _compute_attention(
                 cache_update_index=cache_update_index,
             )
 
-        if self._can_use_flash_attention():
+        if self._use_fused_attention_op():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
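
For reference, the Keras-version check in the new TPU branch is a runtime signature inspection rather than a version string compare. Below is a minimal standalone sketch of that check, assuming a Keras 3 install that exposes `keras.ops.dot_product_attention`; the helper name `soft_cap_supported` is illustrative and not part of the diff.

```python
import inspect

from keras import ops


def soft_cap_supported():
    # Keras >= 3.10 adds an `attn_logits_soft_cap` argument to
    # ops.dot_product_attention; inspecting the signature detects this
    # without pinning an explicit version number.
    sig = inspect.signature(ops.dot_product_attention)
    return "attn_logits_soft_cap" in sig.parameters
```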