
Commit 0589a1c

It's worth not hiding exceptions (#21432)

* Worth not hiding exceptions

  What do you think about logging exceptions instead of hiding them? There is an example of such a hidden exception: head_shards and q_seq_shards were simply never initialized because get_dist() == None, and even though the hardware supports flash attention, the code silently fell back to the standard implementation with O(N^2) memory.

* pre-commit run
1 parent 5d50953 commit 0589a1c
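
The failure mode described in the commit message can be reproduced in miniature. The sketch below is a stylized, self-contained reconstruction, not the real Keras code: get_dist, old_path, and new_path are hypothetical stand-ins (the real lookup is keras.src.distribution.distribution_lib.distribution, imported as get_dist in nn.py).

from absl import logging


def get_dist():
    # Hypothetical stand-in for the distribution lookup; returns None when
    # no Keras distribution has been configured.
    return None


def old_path():
    try:
        try:
            dist = get_dist()
            if dist is not None:
                head_shards = 2  # only assigned when a distribution exists
        except (ImportError, ValueError, AttributeError):
            head_shards = 1  # defaults lived only inside the except clause
        # get_dist() returned None, so head_shards was never assigned and the
        # next line raises NameError, which the broad handler below swallows.
        return f"flash attention, head_shards={head_shards}"
    except Exception:
        return "silent fallback to O(N^2) attention"


def new_path():
    head_shards = 1  # default assigned unconditionally, before the try block
    try:
        dist = get_dist()
        if dist is not None:
            head_shards = 2
    except (ImportError, ValueError, AttributeError):
        logging.exception("Failed to determine distribution context.")
    return f"flash attention, head_shards={head_shards}"


print(old_path())  # silent fallback to O(N^2) attention
print(new_path())  # flash attention, head_shards=1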

File tree

  • keras/src/backend/jax/nn.py

1 file changed (+21 -20 lines)

keras/src/backend/jax/nn.py

Lines changed: 21 additions & 20 deletions
@@ -5,6 +5,7 @@
 import jax
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
+from absl import logging
 from jax import lax
 from jax import nn as jnn
 from jax.experimental.pallas.ops.tpu.splash_attention import (
@@ -1256,6 +1257,9 @@ def dot_product_attention(
     # TPU-specific flash attention path
     if is_tpu and flash_attention:
         # Get sharding parameters from distribution context
+        head_shards = 1
+        # Typically keep q_seq_shards=1 for best performance
+        q_seq_shards = 1
         try:
             from keras.src.distribution.distribution_lib import ModelParallel
             from keras.src.distribution.distribution_lib import (
@@ -1270,12 +1274,12 @@ def dot_product_attention(
                     model_dim_index = mesh.axis_names.index("model")
                     # Set head_shards based on the model dimension of the mesh
                     head_shards = mesh.shape[model_dim_index]
-                    # Typically keep q_seq_shards=1 for best performance
-                    q_seq_shards = 1
         except (ImportError, ValueError, AttributeError):
             # Use default values if detection fails
-            head_shards = 1
-            q_seq_shards = 1
+            logging.exception(
+                "Failed to determine distribution context for sharding. "
+                "Using default head_shards=1 and q_seq_shards=1."
+            )
         # Transpose to ('batch', 'heads', 'length', 'head_dim')
         query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
         key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
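
A note on the replacement except body: absl's logging.exception writes the message at ERROR severity together with the current traceback, so the failure stays visible in the logs. A tiny standalone illustration (not part of the diff), using a toy axis tuple without a "model" entry to trigger the same ValueError the real code can hit:

from absl import logging

axis_names = ("batch",)  # toy mesh axes with no "model" entry
try:
    model_dim_index = axis_names.index("model")  # raises ValueError
except (ImportError, ValueError, AttributeError):
    # Emits an ERROR record with the full traceback instead of silently
    # assigning values; in the real code the defaults head_shards=1 and
    # q_seq_shards=1 were already set before the try block.
    logging.exception(
        "Failed to determine distribution context for sharding. "
        "Using default head_shards=1 and q_seq_shards=1."
    )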
@@ -1328,24 +1332,17 @@ def dot_product_attention(
             # Transpose output back to Keras layout
             return jnp.transpose(output, axes=(0, 2, 1, 3))
         except Exception:
+            logging.exception(
+                "Failed to apply Splash kernel for flash attention. "
+                "Falling back to JAX native dot_product_attention."
+            )
             flash_attention = False

     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        try:
-            return jax.nn.dot_product_attention(
-                query,
-                key,
-                value,
-                bias=bias,
-                mask=mask,
-                scale=scale,
-                is_causal=is_causal,
-                implementation="cudnn" if flash_attention else "xla",
-            )
-        except Exception:
-            # If flash attention fails, fall back to XLA implementation
-            if flash_attention:
+        impls = ["cudnn", "xla"] if flash_attention else ["xla"]
+        for impl in impls:
+            try:
                 return jax.nn.dot_product_attention(
                     query,
                     key,
@@ -1354,9 +1351,13 @@ def dot_product_attention(
                     mask=mask,
                     scale=scale,
                     is_causal=is_causal,
-                    implementation="xla",
+                    implementation=impl,
+                )
+            except Exception:
+                logging.exception(
+                    f"Failed to apply {impl} implementation of "
+                    "jax.nn.dot_product_attention."
                 )
-            raise

     if flash_attention:
         raise RuntimeError(
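
The last two hunks replace the nested try/except with a loop over candidate implementations. Below is a simplified, standalone sketch of that control flow; jax.nn.dot_product_attention is the real JAX API, but attention_with_fallback and the trailing RuntimeError message are illustrative, and most keyword arguments (bias, mask, scale, is_causal) are omitted.

import jax
import jax.numpy as jnp
from absl import logging


def attention_with_fallback(query, key, value, flash_attention=True):
    # Try cuDNN flash attention first when requested, then plain XLA.
    impls = ["cudnn", "xla"] if flash_attention else ["xla"]
    for impl in impls:
        try:
            return jax.nn.dot_product_attention(
                query, key, value, implementation=impl
            )
        except Exception:
            # Log the failure and move on to the next implementation.
            logging.exception(f"Failed to apply {impl} implementation.")
    raise RuntimeError("No dot_product_attention implementation succeeded.")


# Inputs use JAX's (batch, seq_len, num_heads, head_dim) layout.
q = k = v = jnp.ones((1, 8, 2, 16))
out = attention_with_fallback(q, k, v)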
