 import jax
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
+from absl import logging
 from jax import lax
 from jax import nn as jnn
 from jax.experimental.pallas.ops.tpu.splash_attention import (
@@ -1256,6 +1257,9 @@ def dot_product_attention(
     # TPU-specific flash attention path
     if is_tpu and flash_attention:
         # Get sharding parameters from distribution context
+        head_shards = 1
+        # Typically keep q_seq_shards=1 for best performance
+        q_seq_shards = 1
         try:
             from keras.src.distribution.distribution_lib import ModelParallel
             from keras.src.distribution.distribution_lib import (
@@ -1270,12 +1274,12 @@ def dot_product_attention(
                     model_dim_index = mesh.axis_names.index("model")
                     # Set head_shards based on the model dimension of the mesh
                     head_shards = mesh.shape[model_dim_index]
-                    # Typically keep q_seq_shards=1 for best performance
-                    q_seq_shards = 1
         except (ImportError, ValueError, AttributeError):
             # Use default values if detection fails
-            head_shards = 1
-            q_seq_shards = 1
+            logging.exception(
+                "Failed to determine distribution context for sharding. "
+                "Using default head_shards=1 and q_seq_shards=1."
+            )
         # Transpose to ('batch', 'heads', 'length', 'head_dim')
         query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
         key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
@@ -1328,24 +1332,17 @@ def dot_product_attention(
             # Transpose output back to Keras layout
             return jnp.transpose(output, axes=(0, 2, 1, 3))
         except Exception:
+            logging.exception(
+                "Failed to apply Splash kernel for flash attention. "
+                "Falling back to JAX native dot_product_attention."
+            )
             flash_attention = False

     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        try:
-            return jax.nn.dot_product_attention(
-                query,
-                key,
-                value,
-                bias=bias,
-                mask=mask,
-                scale=scale,
-                is_causal=is_causal,
-                implementation="cudnn" if flash_attention else "xla",
-            )
-        except Exception:
-            # If flash attention fails, fall back to XLA implementation
-            if flash_attention:
+        impls = ["cudnn", "xla"] if flash_attention else ["xla"]
+        for impl in impls:
+            try:
                 return jax.nn.dot_product_attention(
                     query,
                     key,
@@ -1354,9 +1351,13 @@ def dot_product_attention(
                     mask=mask,
                     scale=scale,
                     is_causal=is_causal,
-                    implementation="xla",
+                    implementation=impl,
+                )
+            except Exception:
+                logging.exception(
+                    f"Failed to apply {impl} implementation of "
+                    "jax.nn.dot_product_attention."
                 )
-            raise

     if flash_attention:
         raise RuntimeError(
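For readers outside the Keras codebase, the sketch below shows the cudnn-then-xla fallback pattern this diff introduces, in isolation and using only public JAX and absl APIs. It assumes a JAX version that provides jax.nn.dot_product_attention; the helper name `attention_with_fallback` and the toy shapes are illustrative, not part of this PR.

```python
# Minimal standalone sketch of the fallback loop (not part of the PR).
import jax
import jax.numpy as jnp
from absl import logging


def attention_with_fallback(query, key, value, *, is_causal=False, flash_attention=True):
    # Prefer the cudnn (flash attention) kernel when requested, otherwise
    # go straight to the portable XLA implementation.
    impls = ["cudnn", "xla"] if flash_attention else ["xla"]
    for impl in impls:
        try:
            return jax.nn.dot_product_attention(
                query, key, value, is_causal=is_causal, implementation=impl
            )
        except Exception:
            # cudnn is unavailable on CPU/TPU or for unsupported dtypes/shapes;
            # log the failure and try the next implementation instead of raising.
            logging.exception(f"Failed to apply {impl} implementation.")
    raise RuntimeError("No dot_product_attention implementation succeeded.")


# Toy inputs in the (batch, seq_len, num_heads, head_dim) layout expected by
# jax.nn.dot_product_attention.
q = k = v = jnp.ones((1, 8, 2, 16), dtype=jnp.bfloat16)
out = attention_with_fallback(q, k, v, is_causal=True)
print(out.shape)  # (1, 8, 2, 16)
```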